diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 35bb01eae..9373ba626 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -114,7 +114,7 @@ jobs: - name: Haddock # Behaviour of cabal haddock has changed for the worse: https://github.com/haskell/cabal/issues/8725 run: cabal haddock --disable-documentation - if: matrix.mode == 'release' + if: matrix.os != 'windows-latest' && matrix.mode == 'release' - name: Test doctest run: cabal test doctest diff --git a/.gitignore b/.gitignore index 922dbbc64..5e6b52e39 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,5 @@ /docs/_build *.hi *.o + +hie.yaml diff --git a/CHANGELOG.md b/CHANGELOG.md index 53ff0a1a7..c83e3ae25 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,12 +9,29 @@ Policy (PVP)](https://pvp.haskell.org) ## [next] ### Added * Added debugging functions in module `Data.Array.Accelerate.Debug.Trace` ([#485](https://github.com/AccelerateHS/accelerate/pull/485)) + * Support for SIMD data types in expressions. Support for storing a type `a` + in a SIMD vector can be added by deriving an instance for the class `SIMD`. + Pattern synonyms `V2`, `V3`, `V4`, `V8` and `V16` are provided to work with + these at both the Haskell value and embedded expression level. + * Instances for SIMD types in basic numeric classes (e.g. `Num` for `<4 x Float>`) + * Support for 128-bit integers (signed and unsigned) + * Support for 128-bit floating point types (build with cabal flag `float128`) ### Changed * Removed dependency on lens ([#493](https://github.com/AccelerateHS/accelerate/pull/493)) + * The shape constructors (e.g. `Z` and `(:.)`) are now pattern synonyms that + work on both Haskell values and embedded expressions Similarly for the + constructors of `Maybe`, `Either`, `Bool`, and `Ordering`. ### Fixed * Graphviz graph generation of `-ddump-dot` and `-ddump-simpl-dot` ([#384](https://github.com/AccelerateHS/accelerate/issues/384)) + * Bug in `Semigroup` instance for `Maybe` ([#517](https://github.com/AccelerateHS/accelerate/issues/517)) + * Bug in `Ord` instances or tuple types + +### Removed + * Pattern synonyms `Z_`, `(::.)`, `Any_`, `All_`, which are no longer required + * Pattern synonyms `Just_`, `Nothing_` etc., which have been renamed to no + longer require the trailing underscore. ### Contributors diff --git a/accelerate.cabal b/accelerate.cabal index c3ed7ec3f..5228911e3 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -207,6 +207,15 @@ custom-setup , directory >= 1.0 , filepath >= 1.0 +flag float128 + manual: True + default: False + description: + Enable support for 128-bit floating point numbers + . + This requires the library 'quadmath' to be installed. Note that not all + targets support 128-bit floating-point numbers. + flag debug manual: True default: False @@ -364,6 +373,11 @@ library , unique , unordered-containers >= 0.2 , vector >= 0.10 + , wide-word >= 0.1 + + if impl(ghc < 9.0) + build-depends: + integer-gmp exposed-modules: -- The core language and reference implementation @@ -392,14 +406,15 @@ library Data.Array.Accelerate.Analysis.Hash Data.Array.Accelerate.Analysis.Match Data.Array.Accelerate.Array.Data - Data.Array.Accelerate.Array.Remote - Data.Array.Accelerate.Array.Remote.Class - Data.Array.Accelerate.Array.Remote.LRU - Data.Array.Accelerate.Array.Remote.Table + -- Data.Array.Accelerate.Array.Remote + -- Data.Array.Accelerate.Array.Remote.Class + -- Data.Array.Accelerate.Array.Remote.LRU + -- Data.Array.Accelerate.Array.Remote.Table Data.Array.Accelerate.Array.Unique Data.Array.Accelerate.Async - Data.Array.Accelerate.Error Data.Array.Accelerate.Debug.Internal + Data.Array.Accelerate.Error + Data.Array.Accelerate.Interpreter.Arithmetic Data.Array.Accelerate.Lifetime Data.Array.Accelerate.Pretty Data.Array.Accelerate.Representation.Array @@ -433,9 +448,11 @@ library Data.Array.Accelerate.Test.Similar -- Other + Crypto.Hash.XKCP Data.BitSet + Data.Primitive.Bit Data.Primitive.Vec - Crypto.Hash.XKCP + Data.Numeric.Float128 other-modules: Data.Array.Accelerate.Analysis.Hash.TH @@ -445,6 +462,7 @@ library Data.Array.Accelerate.Classes.Eq Data.Array.Accelerate.Classes.Floating Data.Array.Accelerate.Classes.Fractional + Data.Array.Accelerate.Classes.FromBool Data.Array.Accelerate.Classes.FromIntegral Data.Array.Accelerate.Classes.Integral Data.Array.Accelerate.Classes.Num @@ -454,6 +472,9 @@ library Data.Array.Accelerate.Classes.RealFloat Data.Array.Accelerate.Classes.RealFrac Data.Array.Accelerate.Classes.ToFloating + Data.Array.Accelerate.Classes.VEq + Data.Array.Accelerate.Classes.VNum + Data.Array.Accelerate.Classes.VOrd Data.Array.Accelerate.Debug.Internal.Clock Data.Array.Accelerate.Debug.Internal.Flags Data.Array.Accelerate.Debug.Internal.Graph @@ -470,7 +491,10 @@ library Data.Array.Accelerate.Pattern.Either Data.Array.Accelerate.Pattern.Maybe Data.Array.Accelerate.Pattern.Ordering + Data.Array.Accelerate.Pattern.SIMD + Data.Array.Accelerate.Pattern.Shape Data.Array.Accelerate.Pattern.TH + Data.Array.Accelerate.Pattern.Tuple Data.Array.Accelerate.Prelude Data.Array.Accelerate.Pretty.Graphviz Data.Array.Accelerate.Pretty.Graphviz.Monad @@ -489,6 +513,7 @@ library Data.Array.Accelerate.Test.NoFib.Config Language.Haskell.TH.Extra + GHC.TypeLits.Extra if flag(nofib) build-depends: @@ -562,6 +587,7 @@ library cc-options: -O3 -Wall + -std=c11 cxx-options: -O3 @@ -590,6 +616,16 @@ library -caf-all -auto-all + if flag(float128) + cc-options: + -DFLOAT128_ENABLE + + cpp-options: + -DFLOAT128_ENABLE + + extra-libraries: + quadmath + if flag(debug) cc-options: -DACCELERATE_DEBUG diff --git a/cbits/float128.c b/cbits/float128.c new file mode 100644 index 000000000..f16aad747 --- /dev/null +++ b/cbits/float128.c @@ -0,0 +1,119 @@ + +#include +#include + +typedef _Float128 f128; + +union ieee754_quad { + f128 as_float128; + struct { +#if WORDS_BIGENDIAN + uint64_t negative:1; + uint64_t exponent:15; + uint64_t mantissa0:48; + uint64_t mantissa1; +#else + uint64_t mantissa1; + uint64_t mantissa0:48; + uint64_t exponent:15; + uint64_t negative:1; +#endif + } as_uint128; +}; + +/* Operations from Read and Show + */ +void _readq(f128* r, const char* str) { *r = strtoflt128(str, NULL); } +void _showq(char* buf, size_t n, f128 *a) { quadmath_snprintf(buf, n, "%Qf", *a); } + +/* Operations from Num + */ +void _addq(f128* r, const f128* a, const f128* b) { *r = *a + *b; } +void _subq(f128* r, const f128* a, const f128* b) { *r = *a - *b; } +void _mulq(f128* r, const f128* a, const f128* b) { *r = *a * *b; } +void _negateq(f128* r, const f128* a) { *r = - *a; } +void _absq(f128* r, const f128* a) { *r = fabsq(*a); } +void _signumq(f128* r, const f128* a) { *r = (*a > 0.0q) - (*a < 0.0q); } + +/* Operations from Fractional + */ +void _divq(f128* r, const f128* a, const f128* b) { *r = *a / *b; } +void _recipq(f128* r, const f128* a) { *r = 1.0q / *a; } + +/* Operations from Floating + */ +void _piq(f128* r) { *r = M_PIq; } +void _expq(f128* r, const f128* a) { *r = expq(*a); } +void _logq(f128* r, const f128* a) { *r = logq(*a); } +void _sqrtq(f128* r, const f128* a) { *r = sqrtq(*a); } +void _powq(f128* r, const f128* a, const f128* b) { *r = powq(*a, *b); } +void _sinq(f128* r, const f128* a) { *r = sinq(*a); } +void _cosq(f128* r, const f128* a) { *r = cosq(*a); } +void _tanq(f128* r, const f128* a) { *r = tanq(*a); } +void _asinq(f128* r, const f128* a) { *r = asinq(*a); } +void _acosq(f128* r, const f128* a) { *r = acosq(*a); } +void _atanq(f128* r, const f128* a) { *r = atanq(*a); } +void _sinhq(f128* r, const f128* a) { *r = sinhq(*a); } +void _coshq(f128* r, const f128* a) { *r = coshq(*a); } +void _tanhq(f128* r, const f128* a) { *r = tanhq(*a); } +void _asinhq(f128* r, const f128* a) { *r = asinhq(*a); } +void _acoshq(f128* r, const f128* a) { *r = acoshq(*a); } +void _atanhq(f128* r, const f128* a) { *r = atanhq(*a); } +void _log1pq(f128* r, const f128* a) { *r = log1pq(*a); } +void _expm1q(f128* r, const f128* a) { *r = expm1q(*a); } + +/* Operations from RealFrac + */ +void _roundq(f128* r, const f128* a) { *r = roundq(*a); } +void _truncq(f128* r, const f128* a) { *r = truncq(*a); } +void _floorq(f128* r, const f128* a) { *r = floorq(*a); } +void _ceilq(f128* r, const f128* a) { *r = ceilq(*a); } + +/* Operations from RealFloat + */ +uint32_t _isnanq(const f128* a) { return isnanq(*a); } +uint32_t _isinfq(const f128* a) { return isinfq(*a); } +void _frexpq(f128* r, const f128* a, int32_t* b) { *r = frexpq(*a, b); } +void _ldexpq(f128* r, const f128* a, int32_t b) { *r = ldexpq(*a, b); } +void _atan2q(f128* r, const f128* a, const f128* b) { *r = atan2q(*a, *b); } + +/* A (single/double/quad) precision floating point number is denormalized iff: + * - exponent is zero + * - mantissa is non-zero + * - (don't care about the sign bit) + */ +uint32_t _isdenormq(const f128* a) +{ + union ieee754_quad u; + u.as_float128 = *a; + + return (u.as_uint128.exponent == 0 + && (u.as_uint128.mantissa0 != 0 || u.as_uint128.mantissa1 != 0)); +} + +/* A (single/double/quad) precision floating point number is negative zero iff: + * - sign bit is set + * - all other bits are zero + */ +uint32_t _isnegzeroq(const f128* a) +{ + union ieee754_quad u; + u.as_float128 = *a; + + return ( + u.as_uint128.negative && + u.as_uint128.exponent == 0 && + u.as_uint128.mantissa0 == 0 && + u.as_uint128.mantissa1 == 0 + ); +} + +/* Operations from Ord + */ +uint32_t _ltq(const f128* a, const f128* b) { return *a < *b; } +uint32_t _leq(const f128* a, const f128* b) { return *a <= *b; } +uint32_t _gtq(const f128* a, const f128* b) { return *a > *b; } +uint32_t _geq(const f128* a, const f128* b) { return *a <= *b; } +void _fminq(f128* r, const f128* a, const f128* b) { *r = fminq(*a, *b); } +void _fmaxq(f128* r, const f128* a, const f128* b) { *r = fmaxq(*a, *b); } + diff --git a/icebox/Interpreter.hs b/icebox/Interpreter.hs new file mode 100644 index 000000000..b3733425d --- /dev/null +++ b/icebox/Interpreter.hs @@ -0,0 +1,480 @@ +-- | +-- Module : Data.Array.Accelerate.Interpreter +-- Description : Reference backend (interpreted) +-- Copyright : [2008..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + + +-- | Stream a lazily read list of input arrays through the given program, +-- collecting results as we go +-- +streamOut :: Arrays a => Sugar.Seq [a] -> [a] +streamOut seq = let seq' = convertSeqWith config seq + in evalDelayedSeq defaultSeqConfig seq' + + +toSeqOp :: forall slix sl dim co e proxy. (Elt slix, Shape sl, Shape dim, Elt e) + => SliceIndex (EltRepr slix) + (EltRepr sl) + co + (EltRepr dim) + -> proxy slix + -> Array dim e + -> [Array sl e] +toSeqOp sliceIndex _ arr = map (sliceOp sliceIndex arr :: slix -> Array sl e) + (enumSlices sliceIndex (shape arr)) + +-- Sequence evaluation +-- --------------- + +-- Position in sequence. +-- +type SeqPos = Int + +-- Configuration for sequence evaluation. +-- +data SeqConfig = SeqConfig + { chunkSize :: Int -- Allocation limit for a sequence in + -- words. Actual runtime allocation should be the + -- maximum of this size and the size of the + -- largest element in the sequence. + } + +-- Default sequence evaluation configuration for testing purposes. +-- +defaultSeqConfig :: SeqConfig +defaultSeqConfig = SeqConfig { chunkSize = 2 } + +type Chunk a = Vector' a + +-- The empty chunk. O(1). +emptyChunk :: Arrays a => Chunk a +emptyChunk = empty' + +-- Number of arrays in chunk. O(1). +-- +clen :: Arrays a => Chunk a -> Int +clen = length' + +elemsPerChunk :: SeqConfig -> Int -> Int +elemsPerChunk conf n + | n < 1 = chunkSize conf + | otherwise = + let (a,b) = chunkSize conf `quotRem` n + in a + signum b + +-- Drop a number of arrays from a chunk. O(1). Note: Require keeping a +-- scan of element sizes. +-- +cdrop :: Arrays a => Int -> Chunk a -> Chunk a +cdrop = drop' dropOp (fst . offsetsOp) + +-- Get all the shapes of a chunk of arrays. O(1). +-- +chunkShapes :: Chunk (Array sh a) -> Vector sh +chunkShapes = shapes' + +-- Get all the elements of a chunk of arrays. O(1). +-- +chunkElems :: Chunk (Array sh a) -> Vector a +chunkElems = elements' + +-- Convert a vector to a chunk of scalars. +-- +vec2Chunk :: Elt e => Vector e -> Chunk (Scalar e) +vec2Chunk = vec2Vec' + +-- Convert a list of arrays to a chunk. +-- +fromListChunk :: Arrays a => [a] -> Vector' a +fromListChunk = fromList' concatOp + +-- Convert a chunk to a list of arrays. +-- +toListChunk :: Arrays a => Vector' a -> [a] +toListChunk = toList' fetchAllOp + +-- fmap for Chunk. O(n). +-- TODO: Use vectorised function. +mapChunk :: (Arrays a, Arrays b) + => (a -> b) + -> Chunk a -> Chunk b +mapChunk f c = fromListChunk $ map f (toListChunk c) + +-- zipWith for Chunk. O(n). +-- TODO: Use vectorised function. +zipWithChunk :: (Arrays a, Arrays b, Arrays c) + => (a -> b -> c) + -> Chunk a -> Chunk b -> Chunk c +zipWithChunk f c1 c2 = fromListChunk $ zipWith f (toListChunk c1) (toListChunk c2) + +-- A window on a sequence. +-- +data Window a = Window + { chunk :: Chunk a -- Current allocated chunk. + , wpos :: SeqPos -- Position of the window on the sequence, given + -- in number of elements. + } + +-- The initial empty window. +-- +window0 :: Arrays a => Window a +window0 = Window { chunk = emptyChunk, wpos = 0 } + +-- Index the given window by the given index on the sequence. +-- +(!#) :: Arrays a => Window a -> SeqPos -> Chunk a +w !# i + | j <- i - wpos w + , j >= 0 + = cdrop j (chunk w) + -- + | otherwise + = error $ "Window indexed before position. wpos = " ++ show (wpos w) ++ " i = " ++ show i + +-- Move the give window by supplying the next chunk. +-- +moveWin :: Arrays a => Window a -> Chunk a -> Window a +moveWin w c = w { chunk = c + , wpos = wpos w + clen (chunk w) + } + +-- A cursor on a sequence. +-- +data Cursor senv a = Cursor + { ref :: Idx senv a -- Reference to the sequence. + , cpos :: SeqPos -- Position of the cursor on the sequence, + -- given in number of elements. + } + +-- Initial cursor. +-- +cursor0 :: Idx senv a -> Cursor senv a +cursor0 x = Cursor { ref = x, cpos = 0 } + +-- Advance cursor by a relative amount. +-- +moveCursor :: Int -> Cursor senv a -> Cursor senv a +moveCursor k c = c { cpos = cpos c + k } + +-- Valuation for an environment of sequence windows. +-- +data Val' senv where + Empty' :: Val' () + Push' :: Val' senv -> Window t -> Val' (senv, t) + +-- Projection of a window from a window valuation using a de Bruijn +-- index. +-- +prj' :: Idx senv t -> Val' senv -> Window t +prj' ZeroIdx (Push' _ v) = v +prj' (SuccIdx idx) (Push' val _) = prj' idx val + +-- Projection of a chunk from a window valuation using a sequence +-- cursor. +-- +prjChunk :: Arrays a => Cursor senv a -> Val' senv -> Chunk a +prjChunk c senv = prj' (ref c) senv !# cpos c + +-- An executable sequence. +-- +data ExecSeq senv arrs where + ExecP :: Arrays a => Window a -> ExecP senv a -> ExecSeq (senv, a) arrs -> ExecSeq senv arrs + ExecC :: Arrays a => ExecC senv a -> ExecSeq senv a + ExecR :: Arrays a => Cursor senv a -> ExecSeq senv [a] + +-- An executable producer. +-- +data ExecP senv a where + ExecStreamIn :: Int + -> [a] + -> ExecP senv a + + ExecMap :: Arrays a + => (Chunk a -> Chunk b) + -> Cursor senv a + -> ExecP senv b + + ExecZipWith :: (Arrays a, Arrays b) + => (Chunk a -> Chunk b -> Chunk c) + -> Cursor senv a + -> Cursor senv b + -> ExecP senv c + + -- Stream scan skeleton. + ExecScan :: Arrays a + => (s -> Chunk a -> (Chunk r, s)) -- Chunk scanner. + -> s -- Accumulator (internal state). + -> Cursor senv a -- Input stream. + -> ExecP senv r + +-- An executable consumer. +-- +data ExecC senv a where + + -- Stream reduction skeleton. + ExecFold :: Arrays a + => (s -> Chunk a -> s) -- Chunk consumer function. + -> (s -> r) -- Finalizer function. + -> s -- Accumulator (internal state). + -> Cursor senv a -- Input stream. + -> ExecC senv r + + ExecStuple :: IsAtuple a + => Atuple (ExecC senv) (TupleRepr a) + -> ExecC senv a + +minCursor :: ExecSeq senv a -> SeqPos +minCursor s = travS s 0 + where + travS :: ExecSeq senv a -> Int -> SeqPos + travS s i = + case s of + ExecP _ p s' -> travP p i `min` travS s' (i+1) + ExecC c -> travC c i + ExecR _ -> maxBound + + k :: Cursor senv a -> Int -> SeqPos + k c i + | i == idxToInt (ref c) = cpos c + | otherwise = maxBound + + travP :: ExecP senv a -> Int -> SeqPos + travP p i = + case p of + ExecStreamIn _ _ -> maxBound + ExecMap _ c -> k c i + ExecZipWith _ c1 c2 -> k c1 i `min` k c2 i + ExecScan _ _ c -> k c i + + travT :: Atuple (ExecC senv) t -> Int -> SeqPos + travT NilAtup _ = maxBound + travT (SnocAtup t c) i = travT t i `min` travC c i + + travC :: ExecC senv a -> Int -> SeqPos + travC c i = + case c of + ExecFold _ _ _ cu -> k cu i + ExecStuple t -> travT t i + + +evalDelayedSeq + :: SeqConfig + -> DelayedSeq arrs + -> arrs +evalDelayedSeq cfg (DelayedSeq aenv s) | aenv' <- evalExtend aenv Empty + = evalSeq cfg s aenv' + +evalSeq :: forall aenv arrs. + SeqConfig + -> PreOpenSeq DelayedOpenAcc aenv () arrs + -> Val aenv -> arrs +evalSeq conf s aenv = evalSeq' s + where + evalSeq' :: PreOpenSeq DelayedOpenAcc aenv senv arrs -> arrs + evalSeq' (Producer _ s) = evalSeq' s + evalSeq' (Consumer _) = loop (initSeq aenv s) + evalSeq' (Reify _) = reify (initSeq aenv s) + + -- Initialize the producers and the accumulators of the consumers + -- with the given array enviroment. + initSeq :: forall senv arrs'. + Val aenv + -> PreOpenSeq DelayedOpenAcc aenv senv arrs' + -> ExecSeq senv arrs' + initSeq aenv s = + case s of + Producer p s' -> ExecP window0 (initProducer p) (initSeq aenv s') + Consumer c -> ExecC (initConsumer c) + Reify ix -> ExecR (cursor0 ix) + + -- Generate a list from the sequence. + reify :: forall arrs. ExecSeq () [arrs] + -> [arrs] + reify s = case step s Empty' of + (Just s', a) -> a ++ reify s' + (Nothing, a) -> a + + -- Iterate the given sequence until it terminates. + -- A sequence only terminates when one of the producers are exhausted. + loop :: Arrays arrs + => ExecSeq () arrs + -> arrs + loop s = + case step' s of + (Nothing, arrs) -> arrs + (Just s', _) -> loop s' + + where + step' :: ExecSeq () arrs -> (Maybe (ExecSeq () arrs), arrs) + step' s = step s Empty' + + -- One iteration of a sequence. + step :: forall senv arrs'. + ExecSeq senv arrs' + -> Val' senv + -> (Maybe (ExecSeq senv arrs'), arrs') + step s senv = + case s of + ExecP w p s' -> + let (c, mp') = produce p senv + finished = 0 == clen (w !# minCursor s') + w' = if finished then moveWin w c else w + (ms'', a) = step s' (senv `Push'` w') + in case ms'' of + Nothing -> (Nothing, a) + Just s'' | finished + , Just p' <- mp' + -> (Just (ExecP w' p' s''), a) + | not finished + -> (Just (ExecP w' p s''), a) + | otherwise + -> (Nothing, a) + ExecC c -> let (c', acc) = consume c senv + in (Just (ExecC c'), acc) + ExecR ix -> let c = prjChunk ix senv in (Just (ExecR (moveCursor (clen c) ix)), toListChunk c) + + evalA :: DelayedOpenAcc aenv a -> a + evalA acc = evalOpenAcc acc aenv + + evalAF :: DelayedOpenAfun aenv f -> f + evalAF f = evalOpenAfun f aenv + + evalE :: DelayedExp aenv t -> t + evalE exp = evalExp exp aenv + + evalF :: DelayedFun aenv f -> f + evalF fun = evalFun fun aenv + + initProducer :: forall a senv. + Producer DelayedOpenAcc aenv senv a + -> ExecP senv a + initProducer p = + case p of + StreamIn arrs -> ExecStreamIn 1 arrs + ToSeq sliceIndex slix (delayed -> Delayed sh ix _) -> + let n = R.size (R.sliceShape sliceIndex (fromElt sh)) + k = elemsPerChunk conf n + in ExecStreamIn k (toSeqOp sliceIndex slix (fromFunction sh ix)) + MapSeq f x -> ExecMap (mapChunk (evalAF f)) (cursor0 x) + ChunkedMapSeq f x -> ExecMap (evalAF f) (cursor0 x) + ZipWithSeq f x y -> ExecZipWith (zipWithChunk (evalAF f)) (cursor0 x) (cursor0 y) + ScanSeq f e x -> ExecScan scanner (evalE e) (cursor0 x) + where + scanner a c = + let v0 = chunkElems c + (v1, a') = scanl'Op (evalF f) a (delayArray v0) + in (vec2Chunk v1, fromScalar a') + + initConsumer :: forall a senv. + Consumer DelayedOpenAcc aenv senv a + -> ExecC senv a + initConsumer c = + case c of + FoldSeq f e x -> + let f' = evalF f + a0 = fromFunction (Z :. chunkSize conf) (const (evalE e)) + consumer v c = zipWith'Op f' (delayArray v) (delayArray (chunkElems c)) + finalizer = fold1Op f' . delayArray + in ExecFold consumer finalizer a0 (cursor0 x) + FoldSeqFlatten f acc x -> + let f' = evalAF f + a0 = evalA acc + consumer a c = f' a (chunkShapes c) (chunkElems c) + in ExecFold consumer id a0 (cursor0 x) + Stuple t -> + let initTup :: Atuple (Consumer DelayedOpenAcc aenv senv) t -> Atuple (ExecC senv) t + initTup NilAtup = NilAtup + initTup (SnocAtup t c) = SnocAtup (initTup t) (initConsumer c) + in ExecStuple (initTup t) + + delayed :: DelayedOpenAcc aenv (Array sh e) -> Delayed (Array sh e) + delayed AST.Manifest{} = $internalError "evalOpenAcc" "expected delayed array" + delayed AST.Delayed{..} = Delayed (evalExp extentD aenv) + (evalFun indexD aenv) + (evalFun linearIndexD aenv) + +produce :: Arrays a => ExecP senv a -> Val' senv -> (Chunk a, Maybe (ExecP senv a)) +produce p senv = + case p of + ExecStreamIn k xs -> + let (xs', xs'') = (take k xs, drop k xs) + c = fromListChunk xs' + mp = if null xs'' + then Nothing + else Just (ExecStreamIn k xs'') + in (c, mp) + ExecMap f x -> + let c = prjChunk x senv + in (f c, Just $ ExecMap f (moveCursor (clen c) x)) + ExecZipWith f x y -> + let c1 = prjChunk x senv + c2 = prjChunk y senv + k = clen c1 `min` clen c2 + in (f c1 c2, Just $ ExecZipWith f (moveCursor k x) (moveCursor k y)) + ExecScan scanner a x -> + let c = prjChunk x senv + (c', a') = scanner a c + k = clen c + in (c', Just $ ExecScan scanner a' (moveCursor k x)) + +consume :: forall senv a. ExecC senv a -> Val' senv -> (ExecC senv a, a) +consume c senv = + case c of + ExecFold f g acc x -> + let c = prjChunk x senv + acc' = f acc c + -- Even though we call g here, lazy evaluation should guarantee it is + -- only ever called once. + in (ExecFold f g acc' (moveCursor (clen c) x), g acc') + ExecStuple t -> + let consT :: Atuple (ExecC senv) t -> (Atuple (ExecC senv) t, t) + consT NilAtup = (NilAtup, ()) + consT (SnocAtup t c) | (c', acc) <- consume c senv + , (t', acc') <- consT t + = (SnocAtup t' c', (acc', acc)) + (t', acc) = consT t + in (ExecStuple t', toAtuple acc) + +evalExtend :: Extend DelayedOpenAcc aenv aenv' -> Val aenv -> Val aenv' +evalExtend BaseEnv aenv = aenv +evalExtend (PushEnv ext1 ext2) aenv | aenv' <- evalExtend ext1 aenv + = Push aenv' (evalOpenAcc ext2 aenv') + +delayArray :: Array sh e -> Delayed (Array sh e) +delayArray arr@(Array _ adata) = Delayed (shape arr) (arr!) (toElt . unsafeIndexArrayData adata) + +fromScalar :: Scalar a -> a +fromScalar = (!Z) + +concatOp :: forall e. Elt e => [Vector e] -> Vector e +concatOp = concatVectors + +fetchAllOp :: (Shape sh, Elt e) => Segments sh -> Vector e -> [Array sh e] +fetchAllOp segs elts + | (offsets, n) <- offsetsOp segs + , (n ! Z) <= size (shape elts) + = [fetch (segs ! (Z :. i)) (offsets ! (Z :. i)) | i <- [0 .. size (shape segs) - 1]] + | otherwise = error $ "illegal argument to fetchAllOp" + where + fetch sh offset = fromFunction sh (\ ix -> elts ! (Z :. ((toIndex sh ix) + offset))) + +dropOp :: Elt e => Int -> Vector e -> Vector e +dropOp i v -- TODO + -- * Implement using C-style pointer-plus. + -- ; dropOp is used often (from prjChunk), + -- so it ought to be efficient O(1). + | n <- size (shape v) + , i <= n + , i >= 0 + = fromFunction (Z :. n - i) (\ (Z :. j) -> v ! (Z :. i + j)) + | otherwise = error $ "illegal argument to drop" + +offsetsOp :: Shape sh => Segments sh -> (Vector Int, Scalar Int) +offsetsOp segs = scanl'Op (+) 0 $ delayArray (mapOp size (delayArray segs)) + diff --git a/icebox/Vec.hs b/icebox/Vec.hs new file mode 100644 index 000000000..5a324ed8e --- /dev/null +++ b/icebox/Vec.hs @@ -0,0 +1,99 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_HADDOCK hide #-} +-- | +-- Module : Data.Array.Accelerate.Representation.Vec +-- Copyright : [2008..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Representation.Vec + where + +import Data.Array.Accelerate.Type +import Data.Array.Accelerate.Representation.Type +import qualified Data.Primitive.Vec as Prim + +import Control.Monad.ST +import Data.Primitive.ByteArray +import Data.Primitive.Types +import Language.Haskell.TH.Extra + +import GHC.Base ( Int(..), Int#, (-#) ) +import GHC.TypeNats + + +-- | Declares the size of a SIMD vector and the type of its elements. This +-- data type is used to denote the relation between a vector type (Vec +-- n single) with its tuple representation (tuple). Conversions between +-- those types are exposed through 'pack' and 'unpack'. +-- +data VecR (n :: Nat) single tuple where + VecRnil :: SingleType s -> VecR 0 s () + VecRsucc :: VecR n s t -> VecR (n + 1) s (t, s) + + +vecRvector :: KnownNat n => VecR n s tuple -> VectorType (Vec n s) +vecRvector = uncurry VectorType . go + where + go :: VecR n s tuple -> (Int, SingleType s) + go (VecRnil tp) = (0, tp) + go (VecRsucc vec) | (n, tp) <- go vec = (n + 1, tp) + +vecRSingle :: KnownNat n => VecR n s tuple -> SingleType s +vecRSingle vecr = let (VectorType _ s) = vecRvector vecr in s + +vecRtuple :: VecR n s tuple -> TypeR tuple +vecRtuple = snd . go + where + go :: VecR n s tuple -> (SingleType s, TypeR tuple) + go (VecRnil tp) = (tp, TupRunit) + go (VecRsucc vec) | (tp, tuple) <- go vec = (tp, TupRpair tuple (TupRsingle (SingleScalarType tp))) + +pack :: forall n single tuple. KnownNat n => VecR n single tuple -> tuple -> Vec n single +pack vecR tuple + | VectorType n single <- vecRvector vecR + , SingleDict <- singleDict single + = runST $ do + mba <- newByteArray (n * sizeOf (undefined :: single)) + go (n - 1) vecR tuple mba + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# + where + go :: Prim single => Int -> VecR n' single tuple' -> tuple' -> MutableByteArray s -> ST s () + go _ (VecRnil _) () _ = return () + go i (VecRsucc r) (xs, x) mba = do + writeByteArray mba i x + go (i - 1) r xs mba + +unpack :: forall n single tuple. KnownNat n => VecR n single tuple -> Vec n single -> tuple +unpack vecR (Vec ba#) + | VectorType n single <- vecRvector vecR + , (I# n#) <- n + , SingleDict <- singleDict single + = go (n# -# 1#) vecR + where + go :: Prim single => Int# -> VecR n' single tuple' -> tuple' + go _ (VecRnil _) = () + go i# (VecRsucc r) = x `seq` xs `seq` (xs, x) + where + xs = go (i# -# 1#) r + x = indexByteArray# ba# i# + +rnfVecR :: VecR n single tuple -> () +rnfVecR (VecRnil tp) = rnfSingleType tp +rnfVecR (VecRsucc vec) = rnfVecR vec + +liftVecR :: VecR n single tuple -> CodeQ (VecR n single tuple) +liftVecR (VecRnil tp) = [|| VecRnil $$(liftSingleType tp) ||] +liftVecR (VecRsucc vec) = [|| VecRsucc $$(liftVecR vec) ||] + diff --git a/icebox/ci-macos.yml b/icebox/ci-macos.yml index f57a894ed..2d55658fc 100644 --- a/icebox/ci-macos.yml +++ b/icebox/ci-macos.yml @@ -25,7 +25,6 @@ jobs: - "9.0" - "8.10" - "8.8" - - "8.6" env: STACK_FLAGS: "--flag accelerate:nofib" HADDOCK_FLAGS: "--haddock --no-haddock-deps --no-haddock-hyperlink-source --haddock-arguments=\"--no-print-missing-docs\"" diff --git a/icebox/ci-windows.yml b/icebox/ci-windows.yml index 02ab2a953..410b79100 100644 --- a/icebox/ci-windows.yml +++ b/icebox/ci-windows.yml @@ -25,7 +25,6 @@ jobs: - "9.0" - "8.10" - "8.8" - - "8.6" env: __COMPAT_LAYER: "" diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index 929f2fe21..a6aad65fd 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -51,7 +51,7 @@ -- reference implementation defining the semantics of the Accelerate language -- -- * : --- implementation supporting parallel execution on multicore CPUs (e.g. x86). +-- implementation supporting parallel execution on multicore CPUs (e.g. x86-64, AARCH64). -- -- * : -- implementation supporting parallel execution on CUDA-capable NVIDIA GPUs. @@ -179,9 +179,9 @@ module Data.Array.Accelerate ( -- *** Array shapes & indices -- $shapes_and_indices -- - Z(..), (:.)(..), + Z, (:.), DIM0, DIM1, DIM2, DIM3, DIM4, DIM5, DIM6, DIM7, DIM8, DIM9, - Shape, Slice(..), All(..), Any(..), + Shape, Slice(..), All, Any, -- Split(..), Divide(..), Division(..), -- ** Array access @@ -309,19 +309,21 @@ module Data.Array.Accelerate ( Exp, -- ** SIMD vectors - Vec, VecElt, + Vec, SIMD, -- ** Type classes -- *** Basic type classes - Eq(..), - Ord(..), Ordering(..), pattern LT_, pattern EQ_, pattern GT_, + Eq(..), VEq(..), + Ord(..), VOrd(..), Ordering, pattern LT, pattern EQ, pattern GT, Enum, succ, pred, Bounded, minBound, maxBound, + -- Functor(..), (<$>), ($>), void, -- Monad(..), -- *** Numeric type classes Num, (+), (-), (*), negate, abs, signum, fromInteger, + VNum(..), Integral, quot, rem, div, mod, quotRem, divMod, Rational(..), Fractional, (/), recip, fromRational, @@ -330,6 +332,7 @@ module Data.Array.Accelerate ( RealFloat(..), -- *** Numeric conversion classes + FromBool(..), FromIntegral(..), ToFloating(..), @@ -348,15 +351,16 @@ module Data.Array.Accelerate ( pattern T7, pattern T8, pattern T9, pattern T10, pattern T11, pattern T12, pattern T13, pattern T14, pattern T15, pattern T16, - pattern Z_, pattern Ix, pattern (::.), pattern All_, pattern Any_, + pattern Z, pattern (:.), pattern All, pattern Any, pattern I0, pattern I1, pattern I2, pattern I3, pattern I4, pattern I5, pattern I6, pattern I7, pattern I8, pattern I9, - pattern Vec2, pattern V2, - pattern Vec3, pattern V3, - pattern Vec4, pattern V4, - pattern Vec8, pattern V8, - pattern Vec16, pattern V16, + pattern SIMD, + V2, pattern V2, + V3, pattern V3, + V4, pattern V4, + V8, pattern V8, + V16, pattern V16, mkPattern, mkPatterns, @@ -367,14 +371,20 @@ module Data.Array.Accelerate ( -- *** Tuples fst, afst, snd, asnd, curry, uncurry, + -- *** SIMD vectors + splat, pack, unpack, insert, extract, shuffle, + -- *** Flow control - (?), match, cond, while, iterate, + (?), select, match, cond, while, iterate, -- *** Scalar reduction sfoldl, -- *** Logical operations - (&&), (||), not, + not, vnot, + (&&), (&&!), (&&*), + (||), (||!), (||*), + vand, vor, -- *** Numeric operations subtract, even, odd, gcd, lcm, (^), (^^), @@ -386,7 +396,7 @@ module Data.Array.Accelerate ( intersect, -- *** Conversions - ord, chr, boolToInt, bitcast, + ord, chr, bitcast, -- --------------------------------------------------------------------------- -- * Foreign Function Interface (FFI) @@ -418,18 +428,14 @@ module Data.Array.Accelerate ( -- --------------------------------------------------------------------------- -- Types - Int, Int8, Int16, Int32, Int64, - Word, Word8, Word16, Word32, Word64, - Half(..), Float, Double, - Bool(..), pattern True_, pattern False_, - Maybe(..), pattern Nothing_, pattern Just_, - Either(..), pattern Left_, pattern Right_, + Int, Int8, Int16, Int32, Int64, Int128, + Word, Word8, Word16, Word32, Word64, Word128, + Half(..), Float, Double, Float16, Float32, Float64, Float128, + Bool, pattern True, pattern False, + Maybe, pattern Nothing, pattern Just, + Either, pattern Left, pattern Right, Char, - CFloat, CDouble, - CShort, CUShort, CInt, CUInt, CLong, CULong, CLLong, CULLong, - CChar, CSChar, CUChar, - ) where import Data.Array.Accelerate.Classes.Bounded @@ -437,6 +443,7 @@ import Data.Array.Accelerate.Classes.Enum import Data.Array.Accelerate.Classes.Eq import Data.Array.Accelerate.Classes.Floating import Data.Array.Accelerate.Classes.Fractional +import Data.Array.Accelerate.Classes.FromBool import Data.Array.Accelerate.Classes.FromIntegral import Data.Array.Accelerate.Classes.Integral import Data.Array.Accelerate.Classes.Num @@ -445,20 +452,25 @@ import Data.Array.Accelerate.Classes.Rational import Data.Array.Accelerate.Classes.RealFloat import Data.Array.Accelerate.Classes.RealFrac import Data.Array.Accelerate.Classes.ToFloating +import Data.Array.Accelerate.Classes.VEq +import Data.Array.Accelerate.Classes.VNum +import Data.Array.Accelerate.Classes.VOrd import Data.Array.Accelerate.Data.Either import Data.Array.Accelerate.Data.Maybe import Data.Array.Accelerate.Language import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.SIMD +import Data.Array.Accelerate.Pattern.Shape +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Pattern.TH import Data.Array.Accelerate.Prelude import Data.Array.Accelerate.Pretty () -- show instances import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Array ( Array, Arrays, Scalar, Vector, Matrix, Segments, fromFunction, fromFunctionM, toList, fromList ) import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Sugar.Shape hiding ( size, toIndex, fromIndex, intersect ) +import Data.Array.Accelerate.Sugar.Shape hiding ( Z(..), (:.)(..), Any(..), All(..), size, toIndex, fromIndex, intersect ) import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type -import Data.Primitive.Vec import qualified Data.Array.Accelerate.Sugar.Array as S import qualified Data.Array.Accelerate.Sugar.Shape as S @@ -688,4 +700,3 @@ arrayReshape = S.reshape -- * : binary serialisation of arrays using -- * : efficient boxed and unboxed one-dimensional arrays -- - diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index a6d0f75f7..40bc8fd6a 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -1,4 +1,6 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} @@ -7,6 +9,7 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# OPTIONS_HADDOCK hide #-} @@ -78,7 +81,7 @@ module Data.Array.Accelerate.AST ( -- * Internal AST -- ** Array computations Afun, PreAfun, OpenAfun, PreOpenAfun(..), - Acc, OpenAcc(..), PreOpenAcc(..), Direction(..), Message(..), + Acc, OpenAcc(..), PreOpenAcc(..), Direction(..), Message(..), RescaleFactor, ALeftHandSide, ArrayVar, ArrayVars, -- ** Scalar expressions @@ -86,15 +89,14 @@ module Data.Array.Accelerate.AST ( Fun, OpenFun(..), Exp, OpenExp(..), Boundary(..), - PrimConst(..), PrimFun(..), - PrimBool, + PrimBool, PrimMask, PrimMaybe, + BitOrMask, -- ** Extracting type information HasArraysR(..), arrayR, expType, - primConstType, primFunType, -- ** Normal-form @@ -109,7 +111,6 @@ module Data.Array.Accelerate.AST ( rnfExpVar, rnfBoundary, rnfConst, - rnfPrimConst, rnfPrimFun, -- ** Template Haskell @@ -123,7 +124,6 @@ module Data.Array.Accelerate.AST ( liftELeftHandSide, liftExpVar, liftBoundary, - liftPrimConst, liftPrimFun, liftMessage, @@ -144,7 +144,6 @@ import Data.Array.Accelerate.Representation.Slice import Data.Array.Accelerate.Representation.Stencil import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type -import Data.Array.Accelerate.Representation.Vec import Data.Array.Accelerate.Sugar.Foreign import Data.Array.Accelerate.Type import Data.Primitive.Vec @@ -159,8 +158,6 @@ import Language.Haskell.TH.Extra ( CodeQ ) import qualified Language.Haskell.TH.Extra as TH import qualified Language.Haskell.TH.Syntax as TH -import GHC.TypeLits - -- Array expressions -- ----------------- @@ -197,8 +194,8 @@ type ALeftHandSide = LeftHandSide ArrayR type ArrayVar = Var ArrayR type ArrayVars aenv = Vars ArrayR aenv --- Bool is not a primitive type -type PrimBool = TAG +type PrimBool = Bit +type PrimMask n = Vec n Bit type PrimMaybe a = (TAG, ((), a)) -- Trace messages @@ -208,6 +205,13 @@ data Message a where -> Text -> Message a +-- Coercing an array to a different type may involve scaling the size of the +-- array by the given factor. Positive values mean the size of the innermost +-- dimension is multiplied by that value (i.e. the number of elements in the +-- array grows by that factor), negative meaning it shrinks. +-- +type RescaleFactor = INT + -- | Collective array computations parametrised over array variables -- represented with de Bruijn indices. -- @@ -290,6 +294,21 @@ data PreOpenAcc (acc :: Type -> Type -> Type) aenv a where -> acc aenv arrs2 -> PreOpenAcc acc aenv arrs2 + -- Reinterpret the bits of the array as a different type. The size of the + -- innermost dimension is adjusted as necessary. The old and new sizes must be + -- compatible, but this may not be checked; e.g. in the conversion + -- + -- > Array (Z :. n) Float <~~> Array (Z :. m) (Vec 3 Float) + -- + -- then we require + -- + -- > (m, r) = quotRem n 3 and r == 0 + -- + Acoerce :: RescaleFactor + -> TypeR b + -> acc aenv (Array (sh, INT) a) + -> PreOpenAcc acc aenv (Array (sh, INT) b) + -- Array inlet. Triggers (possibly) asynchronous host->device transfer if -- necessary. -- @@ -368,18 +387,18 @@ data PreOpenAcc (acc :: Type -> Type -> Type) aenv a where -- Fold :: Fun aenv (e -> e -> e) -- combination function -> Maybe (Exp aenv e) -- default value - -> acc aenv (Array (sh, Int) e) -- folded array + -> acc aenv (Array (sh, INT) e) -- folded array -> PreOpenAcc acc aenv (Array sh e) -- Segmented fold along the innermost dimension of an array with a given -- /associative/ function -- - FoldSeg :: IntegralType i + FoldSeg :: SingleIntegralType i -> Fun aenv (e -> e -> e) -- combination function -> Maybe (Exp aenv e) -- default value - -> acc aenv (Array (sh, Int) e) -- folded array + -> acc aenv (Array (sh, INT) e) -- folded array -> acc aenv (Segments i) -- segment descriptor - -> PreOpenAcc acc aenv (Array (sh, Int) e) + -> PreOpenAcc acc aenv (Array (sh, INT) e) -- Haskell-style scan of a linear array with a given -- /associative/ function and optionally an initial element @@ -389,8 +408,8 @@ data PreOpenAcc (acc :: Type -> Type -> Type) aenv a where Scan :: Direction -> Fun aenv (e -> e -> e) -- combination function -> Maybe (Exp aenv e) -- initial value - -> acc aenv (Array (sh, Int) e) - -> PreOpenAcc acc aenv (Array (sh, Int) e) + -> acc aenv (Array (sh, INT) e) + -> PreOpenAcc acc aenv (Array (sh, INT) e) -- Like 'Scan', but produces a rightmost (in case of a left-to-right scan) -- fold value and an array with the same length as the input array (the @@ -399,8 +418,8 @@ data PreOpenAcc (acc :: Type -> Type -> Type) aenv a where Scan' :: Direction -> Fun aenv (e -> e -> e) -- combination function -> Exp aenv e -- initial value - -> acc aenv (Array (sh, Int) e) - -> PreOpenAcc acc aenv (Array (sh, Int) e, Array sh e) + -> acc aenv (Array (sh, INT) e) + -> PreOpenAcc acc aenv (Array (sh, INT) e, Array sh e) -- Generalised forward permutation is characterised by a permutation function -- that determines for each element of the source array where it should go in @@ -549,15 +568,31 @@ data OpenExp env aenv t where Nil :: OpenExp env aenv () -- SIMD vectors - VecPack :: KnownNat n - => VecR n s tup - -> OpenExp env aenv tup - -> OpenExp env aenv (Vec n s) + Extract :: ScalarType (Vec n a) + -> SingleIntegralType i + -> OpenExp env aenv (Vec n a) + -> OpenExp env aenv i + -> OpenExp env aenv a + + Insert :: ScalarType (Vec n a) + -> SingleIntegralType i + -> OpenExp env aenv (Vec n a) + -> OpenExp env aenv i + -> OpenExp env aenv a + -> OpenExp env aenv (Vec n a) + + Shuffle :: ScalarType (Vec m a) + -> SingleIntegralType i + -> OpenExp env aenv (Vec n a) + -> OpenExp env aenv (Vec n a) + -> OpenExp env aenv (Vec m i) + -> OpenExp env aenv (Vec m a) - VecUnpack :: KnownNat n - => VecR n s tup - -> OpenExp env aenv (Vec n s) - -> OpenExp env aenv tup + Select :: ScalarType (Vec n a) + -> OpenExp env aenv (PrimMask n) + -> OpenExp env aenv (Vec n a) + -> OpenExp env aenv (Vec n a) + -> OpenExp env aenv (Vec n a) -- Array indices & shapes IndexSlice :: SliceIndex slix sl co sh @@ -574,16 +609,17 @@ data OpenExp env aenv t where ToIndex :: ShapeR sh -> OpenExp env aenv sh -- shape of the array -> OpenExp env aenv sh -- index into the array - -> OpenExp env aenv Int + -> OpenExp env aenv INT FromIndex :: ShapeR sh -> OpenExp env aenv sh -- shape of the array - -> OpenExp env aenv Int -- index into linear representation + -> OpenExp env aenv INT -- index into linear representation -> OpenExp env aenv sh -- Case statement - Case :: OpenExp env aenv TAG - -> [(TAG, OpenExp env aenv b)] -- list of equations + Case :: TagType tag + -> OpenExp env aenv tag + -> [(tag, OpenExp env aenv b)] -- list of equations -> Maybe (OpenExp env aenv b) -- default case -> OpenExp env aenv b @@ -604,9 +640,6 @@ data OpenExp env aenv t where -> t -> OpenExp env aenv t - PrimConst :: PrimConst t - -> OpenExp env aenv t - -- Primitive scalar operations PrimApp :: PrimFun (a -> r) -> OpenExp env aenv a @@ -619,7 +652,7 @@ data OpenExp env aenv t where -> OpenExp env aenv t LinearIndex :: ArrayVar aenv (Array dim t) - -> OpenExp env aenv Int + -> OpenExp env aenv INT -> OpenExp env aenv t -- Array shape. @@ -630,33 +663,31 @@ data OpenExp env aenv t where -- Number of elements of an array given its shape ShapeSize :: ShapeR dim -> OpenExp env aenv dim - -> OpenExp env aenv Int + -> OpenExp env aenv INT -- Unsafe operations (may fail or result in undefined behaviour) -- An unspecified bit pattern Undef :: ScalarType t -> OpenExp env aenv t - -- Reinterpret the bits of a value as a different type - Coerce :: BitSizeEq a b - => ScalarType a + -- Reinterpret the bits of a value as a different type. + -- + -- The types must have the same bit size, but that constraint is not include + -- at this point because GHC's typelits solver is often not powerful enough to + -- discharge that constraint. ---TLM 2022-09-20 + Bitcast :: ScalarType a -> ScalarType b -> OpenExp env aenv a -> OpenExp env aenv b --- |Primitive constant values --- -data PrimConst ty where - - -- constants from Bounded - PrimMinBound :: BoundedType a -> PrimConst a - PrimMaxBound :: BoundedType a -> PrimConst a - - -- constant from Floating - PrimPi :: FloatingType a -> PrimConst a +-- | A bit mask at the width of the corresponding type +-- +type family BitOrMask a where + BitOrMask (Vec n _) = PrimMask n + BitOrMask _ = PrimBool --- |Primitive scalar operations +-- | Primitive scalar operations -- data PrimFun sig where @@ -667,6 +698,8 @@ data PrimFun sig where PrimNeg :: NumType a -> PrimFun (a -> a) PrimAbs :: NumType a -> PrimFun (a -> a) PrimSig :: NumType a -> PrimFun (a -> a) + PrimVAdd :: NumType (Vec n a) -> PrimFun (Vec n a -> a) + PrimVMul :: NumType (Vec n a) -> PrimFun (Vec n a -> a) -- operators from Integral PrimQuot :: IntegralType a -> PrimFun ((a, a) -> a) @@ -677,17 +710,22 @@ data PrimFun sig where PrimDivMod :: IntegralType a -> PrimFun ((a, a) -> (a, a)) -- operators from Bits & FiniteBits - PrimBAnd :: IntegralType a -> PrimFun ((a, a) -> a) - PrimBOr :: IntegralType a -> PrimFun ((a, a) -> a) - PrimBXor :: IntegralType a -> PrimFun ((a, a) -> a) - PrimBNot :: IntegralType a -> PrimFun (a -> a) - PrimBShiftL :: IntegralType a -> PrimFun ((a, Int) -> a) - PrimBShiftR :: IntegralType a -> PrimFun ((a, Int) -> a) - PrimBRotateL :: IntegralType a -> PrimFun ((a, Int) -> a) - PrimBRotateR :: IntegralType a -> PrimFun ((a, Int) -> a) - PrimPopCount :: IntegralType a -> PrimFun (a -> Int) - PrimCountLeadingZeros :: IntegralType a -> PrimFun (a -> Int) - PrimCountTrailingZeros :: IntegralType a -> PrimFun (a -> Int) + PrimBAnd :: IntegralType a -> PrimFun ((a, a) -> a) + PrimBOr :: IntegralType a -> PrimFun ((a, a) -> a) + PrimBXor :: IntegralType a -> PrimFun ((a, a) -> a) + PrimBNot :: IntegralType a -> PrimFun (a -> a) + PrimBShiftL :: IntegralType a -> PrimFun ((a, a) -> a) + PrimBShiftR :: IntegralType a -> PrimFun ((a, a) -> a) + PrimBRotateL :: IntegralType a -> PrimFun ((a, a) -> a) + PrimBRotateR :: IntegralType a -> PrimFun ((a, a) -> a) + PrimPopCount :: IntegralType a -> PrimFun (a -> a) + PrimCountLeadingZeros :: IntegralType a -> PrimFun (a -> a) + PrimCountTrailingZeros :: IntegralType a -> PrimFun (a -> a) + PrimBReverse :: IntegralType a -> PrimFun (a -> a) + PrimBSwap :: IntegralType a -> PrimFun (a -> a) -- prerequisite: BitSize a % 16 == 0 + PrimVBAnd :: IntegralType (Vec n a) -> PrimFun (Vec n a -> a) + PrimVBOr :: IntegralType (Vec n a) -> PrimFun (Vec n a -> a) + PrimVBXor :: IntegralType (Vec n a) -> PrimFun (Vec n a -> a) -- operators from Fractional and Floating PrimFDiv :: FloatingType a -> PrimFun ((a, a) -> a) @@ -710,8 +748,6 @@ data PrimFun sig where PrimFPow :: FloatingType a -> PrimFun ((a, a) -> a) PrimLogBase :: FloatingType a -> PrimFun ((a, a) -> a) - -- FIXME: add missing operations from RealFrac & RealFloat - -- operators from RealFrac PrimTruncate :: FloatingType a -> IntegralType b -> PrimFun (a -> b) PrimRound :: FloatingType a -> IntegralType b -> PrimFun (a -> b) @@ -721,18 +757,20 @@ data PrimFun sig where -- operators from RealFloat PrimAtan2 :: FloatingType a -> PrimFun ((a, a) -> a) - PrimIsNaN :: FloatingType a -> PrimFun (a -> PrimBool) - PrimIsInfinite :: FloatingType a -> PrimFun (a -> PrimBool) + PrimIsNaN :: FloatingType a -> PrimFun (a -> BitOrMask a) + PrimIsInfinite :: FloatingType a -> PrimFun (a -> BitOrMask a) -- relational and equality operators - PrimLt :: SingleType a -> PrimFun ((a, a) -> PrimBool) - PrimGt :: SingleType a -> PrimFun ((a, a) -> PrimBool) - PrimLtEq :: SingleType a -> PrimFun ((a, a) -> PrimBool) - PrimGtEq :: SingleType a -> PrimFun ((a, a) -> PrimBool) - PrimEq :: SingleType a -> PrimFun ((a, a) -> PrimBool) - PrimNEq :: SingleType a -> PrimFun ((a, a) -> PrimBool) - PrimMax :: SingleType a -> PrimFun ((a, a) -> a) - PrimMin :: SingleType a -> PrimFun ((a, a) -> a) + PrimLt :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimGt :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimLtEq :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimGtEq :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimEq :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimNEq :: ScalarType a -> PrimFun ((a, a) -> BitOrMask a) + PrimMin :: ScalarType a -> PrimFun ((a, a) -> a) + PrimMax :: ScalarType a -> PrimFun ((a, a) -> a) + PrimVMin :: ScalarType (Vec n a) -> PrimFun (Vec n a -> a) + PrimVMax :: ScalarType (Vec n a) -> PrimFun (Vec n a -> a) -- logical operators -- @@ -744,13 +782,17 @@ data PrimFun sig where -- short-circuiting, while (&&!) and (||!) are strict versions of these -- operators, which are defined using PrimLAnd and PrimLOr. -- - PrimLAnd :: PrimFun ((PrimBool, PrimBool) -> PrimBool) - PrimLOr :: PrimFun ((PrimBool, PrimBool) -> PrimBool) - PrimLNot :: PrimFun (PrimBool -> PrimBool) + PrimLAnd :: BitType a -> PrimFun ((a, a) -> a) + PrimLOr :: BitType a -> PrimFun ((a, a) -> a) + PrimLNot :: BitType a -> PrimFun (a -> a) + PrimVLAnd :: BitType (Vec n a) -> PrimFun (Vec n a -> a) + PrimVLOr :: BitType (Vec n a) -> PrimFun (Vec n a -> a) -- general conversion between types PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b) PrimToFloating :: NumType a -> FloatingType b -> PrimFun (a -> b) + PrimToBool :: IntegralType a -> BitType b -> PrimFun (a -> b) + PrimFromBool :: BitType a -> IntegralType b -> PrimFun (a -> b) -- Type utilities @@ -772,6 +814,8 @@ instance HasArraysR acc => HasArraysR (PreOpenAcc acc) where arraysR (Apair as bs) = TupRpair (arraysR as) (arraysR bs) arraysR Anil = TupRunit arraysR (Atrace _ _ bs) = arraysR bs + arraysR (Acoerce _ bR as) = let ArrayR shR _ = arrayR as + in arraysRarray shR bR arraysR (Apply aR _ _) = aR arraysR (Aforeign r _ _ _) = r arraysR (Acond _ a _) = arraysR a @@ -812,39 +856,49 @@ expType = \case Foreign tR _ _ _ -> tR Pair e1 e2 -> TupRpair (expType e1) (expType e2) Nil -> TupRunit - VecPack vecR _ -> TupRsingle $ VectorScalarType $ vecRvector vecR - VecUnpack vecR _ -> vecRtuple vecR + Extract tR _ _ _ -> TupRsingle (scalar tR) + where + scalar :: ScalarType (Vec n a) -> ScalarType a + scalar (NumScalarType t) = NumScalarType (num t) + scalar (BitScalarType t) = BitScalarType (bit t) + + bit :: BitType (Vec n a) -> BitType a + bit TypeMask{} = TypeBit + + num :: NumType (Vec n a) -> NumType a + num (IntegralNumType t) = IntegralNumType (integral t) + num (FloatingNumType t) = FloatingNumType (floating t) + + integral :: IntegralType (Vec n a) -> IntegralType a + integral (SingleIntegralType t) = case t of + integral (VectorIntegralType _ t) = SingleIntegralType t + + floating :: FloatingType (Vec n a) -> FloatingType a + floating (SingleFloatingType t) = case t of + floating (VectorFloatingType _ t) = SingleFloatingType t + -- + Insert t _ _ _ _ -> TupRsingle t + Shuffle t _ _ _ _ -> TupRsingle t + Select t _ _ _ -> TupRsingle t IndexSlice si _ _ -> shapeType $ sliceShapeR si IndexFull si _ _ -> shapeType $ sliceDomainR si - ToIndex{} -> TupRsingle scalarTypeInt - FromIndex shr _ _ -> shapeType shr - Case _ ((_,e):_) _ -> expType e - Case _ [] (Just e) -> expType e + ToIndex{} -> TupRsingle scalarType + FromIndex shR _ _ -> shapeType shR + Case _ _ ((_,e):_) _ -> expType e + Case _ _ [] (Just e) -> expType e Case{} -> internalError "empty case encountered" Cond _ e _ -> expType e While _ (Lam lhs _) _ -> lhsToTupR lhs While{} -> error "What's the matter, you're running in the shadows" Const tR _ -> TupRsingle tR - PrimConst c -> TupRsingle $ SingleScalarType $ primConstType c PrimApp f _ -> snd $ primFunType f Index (Var repr _) _ -> arrayRtype repr LinearIndex (Var repr _) _ -> arrayRtype repr Shape (Var repr _) -> shapeType $ arrayRshape repr - ShapeSize{} -> TupRsingle scalarTypeInt + ShapeSize{} -> TupRsingle (scalarType @INT) Undef tR -> TupRsingle tR - Coerce _ tR _ -> TupRsingle tR - -primConstType :: PrimConst a -> SingleType a -primConstType = \case - PrimMinBound t -> bounded t - PrimMaxBound t -> bounded t - PrimPi t -> floating t - where - bounded :: BoundedType a -> SingleType a - bounded (IntegralBoundedType t) = NumSingleType $ IntegralNumType t + Bitcast _ tR _ -> TupRsingle tR - floating :: FloatingType t -> SingleType t - floating = NumSingleType . FloatingNumType primFunType :: PrimFun (a -> b) -> (TypeR a, TypeR b) primFunType = \case @@ -855,27 +909,34 @@ primFunType = \case PrimNeg t -> unary' $ num t PrimAbs t -> unary' $ num t PrimSig t -> unary' $ num t + PrimVAdd t -> unary (num t) (scalar_num t) + PrimVMul t -> unary (num t) (scalar_num t) -- Integral PrimQuot t -> binary' $ integral t PrimRem t -> binary' $ integral t - PrimQuotRem t -> unary' $ integral t `TupRpair` integral t + PrimQuotRem t -> unary' $ integral t `TupRpair` integral t PrimIDiv t -> binary' $ integral t PrimMod t -> binary' $ integral t - PrimDivMod t -> unary' $ integral t `TupRpair` integral t + PrimDivMod t -> unary' $ integral t `TupRpair` integral t -- Bits & FiniteBits PrimBAnd t -> binary' $ integral t PrimBOr t -> binary' $ integral t PrimBXor t -> binary' $ integral t - PrimBNot t -> unary' $ integral t - PrimBShiftL t -> (integral t `TupRpair` tint, integral t) - PrimBShiftR t -> (integral t `TupRpair` tint, integral t) - PrimBRotateL t -> (integral t `TupRpair` tint, integral t) - PrimBRotateR t -> (integral t `TupRpair` tint, integral t) - PrimPopCount t -> unary (integral t) tint - PrimCountLeadingZeros t -> unary (integral t) tint - PrimCountTrailingZeros t -> unary (integral t) tint + PrimBNot t -> unary' $ integral t + PrimBShiftL t -> binary' $ integral t + PrimBShiftR t -> binary' $ integral t + PrimBRotateL t -> binary' $ integral t + PrimBRotateR t -> binary' $ integral t + PrimPopCount t -> unary' $ integral t + PrimCountLeadingZeros t -> unary' $ integral t + PrimCountTrailingZeros t -> unary' $ integral t + PrimBSwap t -> unary' $ integral t + PrimBReverse t -> unary' $ integral t + PrimVBAnd t -> unary (integral t) (scalar_integral t) + PrimVBOr t -> unary (integral t) (scalar_integral t) + PrimVBXor t -> unary (integral t) (scalar_integral t) -- Fractional, Floating PrimFDiv t -> binary' $ floating t @@ -906,8 +967,8 @@ primFunType = \case -- RealFloat PrimAtan2 t -> binary' $ floating t - PrimIsNaN t -> unary (floating t) tbool - PrimIsInfinite t -> unary (floating t) tbool + PrimIsNaN t -> unary (floating t) (mask_floating t) + PrimIsInfinite t -> unary (floating t) (mask_floating t) -- Relational and equality PrimLt t -> compare' t @@ -916,32 +977,91 @@ primFunType = \case PrimGtEq t -> compare' t PrimEq t -> compare' t PrimNEq t -> compare' t - PrimMax t -> binary' $ single t PrimMin t -> binary' $ single t + PrimMax t -> binary' $ single t + PrimVMin t -> unary (single t) (scalar_scalar t) + PrimVMax t -> unary (single t) (scalar_scalar t) -- Logical - PrimLAnd -> binary' tbool - PrimLOr -> binary' tbool - PrimLNot -> unary' tbool + PrimLAnd t -> binary' (bit t) + PrimLOr t -> binary' (bit t) + PrimLNot t -> unary' (bit t) + PrimVLAnd t -> unary (bit t) (scalar_bit t) + PrimVLOr t -> unary (bit t) (scalar_bit t) -- general conversion between types PrimFromIntegral a b -> unary (integral a) (num b) PrimToFloating a b -> unary (num a) (floating b) + PrimToBool a b -> unary (integral a) (bit b) + PrimFromBool a b -> unary (bit a) (integral b) where unary a b = (a, b) unary' a = unary a a binary a b = (a `TupRpair` a, b) binary' a = binary a a - compare' a = binary (single a) tbool - - single = TupRsingle . SingleScalarType - num = TupRsingle . SingleScalarType . NumSingleType - integral = num . IntegralNumType - floating = num . FloatingNumType - - tbool = TupRsingle scalarTypeWord8 - tint = TupRsingle scalarTypeInt + compare' a = binary (single a) (mask_scalar a) + + single = TupRsingle + num = single . NumScalarType + bit = single . BitScalarType + integral = num . IntegralNumType + floating = num . FloatingNumType + + mask_scalar :: ScalarType t -> TypeR (BitOrMask t) + mask_scalar (NumScalarType t) = mask_num t + mask_scalar (BitScalarType t) = mask_bit t + + mask_bit :: BitType t -> TypeR (BitOrMask t) + mask_bit TypeBit = bit TypeBit + mask_bit (TypeMask n) = bit (TypeMask n) + + mask_num :: NumType t -> TypeR (BitOrMask t) + mask_num (IntegralNumType t) = mask_integral t + mask_num (FloatingNumType t) = mask_floating t + + mask_integral :: IntegralType t -> TypeR (BitOrMask t) + mask_integral (VectorIntegralType n _) = single (BitScalarType (TypeMask n)) + mask_integral (SingleIntegralType t) = case t of + TypeInt8 -> single (scalarType @Bit) + TypeInt16 -> single (scalarType @Bit) + TypeInt32 -> single (scalarType @Bit) + TypeInt64 -> single (scalarType @Bit) + TypeInt128 -> single (scalarType @Bit) + TypeWord8 -> single (scalarType @Bit) + TypeWord16 -> single (scalarType @Bit) + TypeWord32 -> single (scalarType @Bit) + TypeWord64 -> single (scalarType @Bit) + TypeWord128 -> single (scalarType @Bit) + + mask_floating :: FloatingType t -> TypeR (BitOrMask t) + mask_floating (VectorFloatingType n _) = single (BitScalarType (TypeMask n)) + mask_floating (SingleFloatingType t) = case t of + TypeFloat16 -> single (scalarType @Bit) + TypeFloat32 -> single (scalarType @Bit) + TypeFloat64 -> single (scalarType @Bit) + TypeFloat128 -> single (scalarType @Bit) + + scalar_scalar :: ScalarType (Vec n a) -> TypeR a + scalar_scalar (NumScalarType t) = scalar_num t + scalar_scalar (BitScalarType t) = scalar_bit t + + scalar_bit :: BitType (Vec n a) -> TypeR a + scalar_bit TypeMask{} = bit TypeBit + + scalar_num :: NumType (Vec n a) -> TypeR a + scalar_num (IntegralNumType t) = scalar_integral t + scalar_num (FloatingNumType t) = scalar_floating t + + scalar_integral :: IntegralType (Vec n a) -> TypeR a + scalar_integral = \case + SingleIntegralType t -> case t of + VectorIntegralType _ t -> integral (SingleIntegralType t) + + scalar_floating :: FloatingType (Vec n a) -> TypeR a + scalar_floating = \case + SingleFloatingType t -> case t of + VectorFloatingType _ t -> floating (SingleFloatingType t) -- Normal form data @@ -996,6 +1116,7 @@ rnfPreOpenAcc rnfA pacc = Apair as bs -> rnfA as `seq` rnfA bs Anil -> () Atrace msg as bs -> rnfM msg `seq` rnfA as `seq` rnfA bs + Acoerce scale bR as -> scale `seq` rnfTypeR bR `seq` rnfA as Apply repr afun acc -> rnfTupR rnfArrayR repr `seq` rnfAF afun `seq` rnfA acc Aforeign repr asm afun a -> rnfTupR rnfArrayR repr `seq` rnf (strForeign asm) `seq` rnfAF afun `seq` rnfA a Acond p a1 a2 -> rnfE p `seq` rnfA a1 `seq` rnfA a2 @@ -1010,7 +1131,7 @@ rnfPreOpenAcc rnfA pacc = Map tp f a -> rnfTypeR tp `seq` rnfF f `seq` rnfA a ZipWith tp f a1 a2 -> rnfTypeR tp `seq` rnfF f `seq` rnfA a1 `seq` rnfA a2 Fold f z a -> rnfF f `seq` rnfMaybe rnfE z `seq` rnfA a - FoldSeg i f z a s -> rnfIntegralType i `seq` rnfF f `seq` rnfMaybe rnfE z `seq` rnfA a `seq` rnfA s + FoldSeg i f z a s -> rnfSingleIntegralType i `seq` rnfF f `seq` rnfMaybe rnfE z `seq` rnfA a `seq` rnfA s Scan d f z a -> d `seq` rnfF f `seq` rnfMaybe rnfE z `seq` rnfA a Scan' d f z a -> d `seq` rnfF f `seq` rnfE z `seq` rnfA a Permute f d p a -> rnfF f `seq` rnfA d `seq` rnfF p `seq` rnfA a @@ -1071,22 +1192,23 @@ rnfOpenExp topExp = Undef tp -> rnfScalarType tp Pair a b -> rnfE a `seq` rnfE b Nil -> () - VecPack vecr e -> rnfVecR vecr `seq` rnfE e - VecUnpack vecr e -> rnfVecR vecr `seq` rnfE e + Extract vR iR v i -> rnfScalarType vR `seq` rnfSingleIntegralType iR `seq` rnfE v `seq` rnfE i + Insert vR iR v i x -> rnfScalarType vR `seq` rnfSingleIntegralType iR `seq` rnfE v `seq` rnfE i `seq` rnfE x + Shuffle eR iR x y i -> rnfScalarType eR `seq` rnfSingleIntegralType iR `seq` rnfE x `seq` rnfE y `seq` rnfE i + Select eR m x y -> rnfScalarType eR `seq` rnfE m `seq` rnfE x `seq` rnfE y IndexSlice slice slix sh -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sh IndexFull slice slix sl -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sl ToIndex shr sh ix -> rnfShapeR shr `seq` rnfE sh `seq` rnfE ix FromIndex shr sh ix -> rnfShapeR shr `seq` rnfE sh `seq` rnfE ix - Case e rhs def -> rnfE e `seq` rnfList (\(t,c) -> t `seq` rnfE c) rhs `seq` rnfMaybe rnfE def + Case pR p rhs def -> rnfTagType pR `seq` rnfE p `seq` rnfList (\(t,c) -> t `seq` rnfE c) rhs `seq` rnfMaybe rnfE def Cond p e1 e2 -> rnfE p `seq` rnfE e1 `seq` rnfE e2 While p f x -> rnfF p `seq` rnfF f `seq` rnfE x - PrimConst c -> rnfPrimConst c PrimApp f x -> rnfPrimFun f `seq` rnfE x Index a ix -> rnfArrayVar a `seq` rnfE ix LinearIndex a ix -> rnfArrayVar a `seq` rnfE ix Shape a -> rnfArrayVar a ShapeSize shr sh -> rnfShapeR shr `seq` rnfE sh - Coerce t1 t2 e -> rnfScalarType t1 `seq` rnfScalarType t2 `seq` rnfE e + Bitcast t1 t2 e -> rnfScalarType t1 `seq` rnfScalarType t2 `seq` rnfE e rnfExpVar :: ExpVar env t -> () rnfExpVar = rnfVar rnfScalarType @@ -1099,74 +1221,83 @@ rnfConst TupRunit () = () rnfConst (TupRsingle t) !_ = rnfScalarType t -- scalars should have (nf == whnf) rnfConst (TupRpair ta tb) (a,b) = rnfConst ta a `seq` rnfConst tb b -rnfPrimConst :: PrimConst c -> () -rnfPrimConst (PrimMinBound t) = rnfBoundedType t -rnfPrimConst (PrimMaxBound t) = rnfBoundedType t -rnfPrimConst (PrimPi t) = rnfFloatingType t - rnfPrimFun :: PrimFun f -> () -rnfPrimFun (PrimAdd t) = rnfNumType t -rnfPrimFun (PrimSub t) = rnfNumType t -rnfPrimFun (PrimMul t) = rnfNumType t -rnfPrimFun (PrimNeg t) = rnfNumType t -rnfPrimFun (PrimAbs t) = rnfNumType t -rnfPrimFun (PrimSig t) = rnfNumType t -rnfPrimFun (PrimQuot t) = rnfIntegralType t -rnfPrimFun (PrimRem t) = rnfIntegralType t -rnfPrimFun (PrimQuotRem t) = rnfIntegralType t -rnfPrimFun (PrimIDiv t) = rnfIntegralType t -rnfPrimFun (PrimMod t) = rnfIntegralType t -rnfPrimFun (PrimDivMod t) = rnfIntegralType t -rnfPrimFun (PrimBAnd t) = rnfIntegralType t -rnfPrimFun (PrimBOr t) = rnfIntegralType t -rnfPrimFun (PrimBXor t) = rnfIntegralType t -rnfPrimFun (PrimBNot t) = rnfIntegralType t -rnfPrimFun (PrimBShiftL t) = rnfIntegralType t -rnfPrimFun (PrimBShiftR t) = rnfIntegralType t -rnfPrimFun (PrimBRotateL t) = rnfIntegralType t -rnfPrimFun (PrimBRotateR t) = rnfIntegralType t -rnfPrimFun (PrimPopCount t) = rnfIntegralType t -rnfPrimFun (PrimCountLeadingZeros t) = rnfIntegralType t -rnfPrimFun (PrimCountTrailingZeros t) = rnfIntegralType t -rnfPrimFun (PrimFDiv t) = rnfFloatingType t -rnfPrimFun (PrimRecip t) = rnfFloatingType t -rnfPrimFun (PrimSin t) = rnfFloatingType t -rnfPrimFun (PrimCos t) = rnfFloatingType t -rnfPrimFun (PrimTan t) = rnfFloatingType t -rnfPrimFun (PrimAsin t) = rnfFloatingType t -rnfPrimFun (PrimAcos t) = rnfFloatingType t -rnfPrimFun (PrimAtan t) = rnfFloatingType t -rnfPrimFun (PrimSinh t) = rnfFloatingType t -rnfPrimFun (PrimCosh t) = rnfFloatingType t -rnfPrimFun (PrimTanh t) = rnfFloatingType t -rnfPrimFun (PrimAsinh t) = rnfFloatingType t -rnfPrimFun (PrimAcosh t) = rnfFloatingType t -rnfPrimFun (PrimAtanh t) = rnfFloatingType t -rnfPrimFun (PrimExpFloating t) = rnfFloatingType t -rnfPrimFun (PrimSqrt t) = rnfFloatingType t -rnfPrimFun (PrimLog t) = rnfFloatingType t -rnfPrimFun (PrimFPow t) = rnfFloatingType t -rnfPrimFun (PrimLogBase t) = rnfFloatingType t -rnfPrimFun (PrimTruncate f i) = rnfFloatingType f `seq` rnfIntegralType i -rnfPrimFun (PrimRound f i) = rnfFloatingType f `seq` rnfIntegralType i -rnfPrimFun (PrimFloor f i) = rnfFloatingType f `seq` rnfIntegralType i -rnfPrimFun (PrimCeiling f i) = rnfFloatingType f `seq` rnfIntegralType i -rnfPrimFun (PrimIsNaN t) = rnfFloatingType t -rnfPrimFun (PrimIsInfinite t) = rnfFloatingType t -rnfPrimFun (PrimAtan2 t) = rnfFloatingType t -rnfPrimFun (PrimLt t) = rnfSingleType t -rnfPrimFun (PrimGt t) = rnfSingleType t -rnfPrimFun (PrimLtEq t) = rnfSingleType t -rnfPrimFun (PrimGtEq t) = rnfSingleType t -rnfPrimFun (PrimEq t) = rnfSingleType t -rnfPrimFun (PrimNEq t) = rnfSingleType t -rnfPrimFun (PrimMax t) = rnfSingleType t -rnfPrimFun (PrimMin t) = rnfSingleType t -rnfPrimFun PrimLAnd = () -rnfPrimFun PrimLOr = () -rnfPrimFun PrimLNot = () -rnfPrimFun (PrimFromIntegral i n) = rnfIntegralType i `seq` rnfNumType n -rnfPrimFun (PrimToFloating n f) = rnfNumType n `seq` rnfFloatingType f +rnfPrimFun = \case + PrimAdd t -> rnfNumType t + PrimSub t -> rnfNumType t + PrimMul t -> rnfNumType t + PrimNeg t -> rnfNumType t + PrimAbs t -> rnfNumType t + PrimSig t -> rnfNumType t + PrimVAdd t -> rnfNumType t + PrimVMul t -> rnfNumType t + PrimQuot t -> rnfIntegralType t + PrimRem t -> rnfIntegralType t + PrimQuotRem t -> rnfIntegralType t + PrimIDiv t -> rnfIntegralType t + PrimMod t -> rnfIntegralType t + PrimDivMod t -> rnfIntegralType t + PrimBAnd t -> rnfIntegralType t + PrimBOr t -> rnfIntegralType t + PrimBXor t -> rnfIntegralType t + PrimBNot t -> rnfIntegralType t + PrimBShiftL t -> rnfIntegralType t + PrimBShiftR t -> rnfIntegralType t + PrimBRotateL t -> rnfIntegralType t + PrimBRotateR t -> rnfIntegralType t + PrimPopCount t -> rnfIntegralType t + PrimCountLeadingZeros t -> rnfIntegralType t + PrimCountTrailingZeros t -> rnfIntegralType t + PrimBReverse t -> rnfIntegralType t + PrimBSwap t -> rnfIntegralType t + PrimVBAnd t -> rnfIntegralType t + PrimVBOr t -> rnfIntegralType t + PrimVBXor t -> rnfIntegralType t + PrimFDiv t -> rnfFloatingType t + PrimRecip t -> rnfFloatingType t + PrimSin t -> rnfFloatingType t + PrimCos t -> rnfFloatingType t + PrimTan t -> rnfFloatingType t + PrimAsin t -> rnfFloatingType t + PrimAcos t -> rnfFloatingType t + PrimAtan t -> rnfFloatingType t + PrimSinh t -> rnfFloatingType t + PrimCosh t -> rnfFloatingType t + PrimTanh t -> rnfFloatingType t + PrimAsinh t -> rnfFloatingType t + PrimAcosh t -> rnfFloatingType t + PrimAtanh t -> rnfFloatingType t + PrimExpFloating t -> rnfFloatingType t + PrimSqrt t -> rnfFloatingType t + PrimLog t -> rnfFloatingType t + PrimFPow t -> rnfFloatingType t + PrimLogBase t -> rnfFloatingType t + PrimTruncate f i -> rnfFloatingType f `seq` rnfIntegralType i + PrimRound f i -> rnfFloatingType f `seq` rnfIntegralType i + PrimFloor f i -> rnfFloatingType f `seq` rnfIntegralType i + PrimCeiling f i -> rnfFloatingType f `seq` rnfIntegralType i + PrimAtan2 t -> rnfFloatingType t + PrimIsNaN t -> rnfFloatingType t + PrimIsInfinite t -> rnfFloatingType t + PrimLt t -> rnfScalarType t + PrimGt t -> rnfScalarType t + PrimLtEq t -> rnfScalarType t + PrimGtEq t -> rnfScalarType t + PrimEq t -> rnfScalarType t + PrimNEq t -> rnfScalarType t + PrimMin t -> rnfScalarType t + PrimMax t -> rnfScalarType t + PrimVMin t -> rnfScalarType t + PrimVMax t -> rnfScalarType t + PrimLAnd t -> rnfBitType t + PrimLOr t -> rnfBitType t + PrimLNot t -> rnfBitType t + PrimVLAnd t -> rnfBitType t + PrimVLOr t -> rnfBitType t + PrimFromIntegral i n -> rnfIntegralType i `seq` rnfNumType n + PrimToFloating n f -> rnfNumType n `seq` rnfFloatingType f + PrimToBool i b -> rnfIntegralType i `seq` rnfBitType b + PrimFromBool b i -> rnfBitType b `seq` rnfIntegralType i -- Template Haskell @@ -1204,6 +1335,7 @@ liftPreOpenAcc liftA pacc = Apair as bs -> [|| Apair $$(liftA as) $$(liftA bs) ||] Anil -> [|| Anil ||] Atrace msg as bs -> [|| Atrace $$(liftMessage (arraysR as) msg) $$(liftA as) $$(liftA bs) ||] + Acoerce scale bR a -> [|| Acoerce scale $$(liftTypeR bR) $$(liftA a) ||] Apply repr f a -> [|| Apply $$(liftArraysR repr) $$(liftAF f) $$(liftA a) ||] Aforeign repr asm f a -> [|| Aforeign $$(liftArraysR repr) $$(liftForeign asm) $$(liftPreOpenAfun liftA f) $$(liftA a) ||] Acond p t e -> [|| Acond $$(liftE p) $$(liftA t) $$(liftA e) ||] @@ -1218,7 +1350,7 @@ liftPreOpenAcc liftA pacc = Map tp f a -> [|| Map $$(liftTypeR tp) $$(liftF f) $$(liftA a) ||] ZipWith tp f a b -> [|| ZipWith $$(liftTypeR tp) $$(liftF f) $$(liftA a) $$(liftA b) ||] Fold f z a -> [|| Fold $$(liftF f) $$(liftMaybe liftE z) $$(liftA a) ||] - FoldSeg i f z a s -> [|| FoldSeg $$(liftIntegralType i) $$(liftF f) $$(liftMaybe liftE z) $$(liftA a) $$(liftA s) ||] + FoldSeg i f z a s -> [|| FoldSeg $$(liftSingleIntegralType i) $$(liftF f) $$(liftMaybe liftE z) $$(liftA a) $$(liftA s) ||] Scan d f z a -> [|| Scan $$(liftDirection d) $$(liftF f) $$(liftMaybe liftE z) $$(liftA a) ||] Scan' d f z a -> [|| Scan' $$(liftDirection d) $$(liftF f) $$(liftE z) $$(liftA a) ||] Permute f d p a -> [|| Permute $$(liftF f) $$(liftA d) $$(liftF p) $$(liftA a) ||] @@ -1291,22 +1423,23 @@ liftOpenExp pexp = Undef tp -> [|| Undef $$(liftScalarType tp) ||] Pair a b -> [|| Pair $$(liftE a) $$(liftE b) ||] Nil -> [|| Nil ||] - VecPack vecr e -> [|| VecPack $$(liftVecR vecr) $$(liftE e) ||] - VecUnpack vecr e -> [|| VecUnpack $$(liftVecR vecr) $$(liftE e) ||] + Extract vR iR v i -> [|| Extract $$(liftScalarType vR) $$(liftSingleIntegralType iR) $$(liftE v) $$(liftE i) ||] + Insert vR iR v i x -> [|| Insert $$(liftScalarType vR) $$(liftSingleIntegralType iR) $$(liftE v) $$(liftE i) $$(liftE x) ||] + Shuffle eR iR x y i -> [|| Shuffle $$(liftScalarType eR) $$(liftSingleIntegralType iR) $$(liftE x) $$(liftE y) $$(liftE i) ||] + Select eR m x y -> [|| Select $$(liftScalarType eR) $$(liftE m) $$(liftE x) $$(liftE y) ||] IndexSlice slice slix sh -> [|| IndexSlice $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sh) ||] IndexFull slice slix sl -> [|| IndexFull $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sl) ||] ToIndex shr sh ix -> [|| ToIndex $$(liftShapeR shr) $$(liftE sh) $$(liftE ix) ||] FromIndex shr sh ix -> [|| FromIndex $$(liftShapeR shr) $$(liftE sh) $$(liftE ix) ||] - Case p rhs def -> [|| Case $$(liftE p) $$(liftList (\(t,c) -> [|| (t, $$(liftE c)) ||]) rhs) $$(liftMaybe liftE def) ||] + Case pR p rhs def -> [|| Case $$(liftTagType pR) $$(liftE p) $$(liftList (\(t,c) -> [|| ($$(liftTag pR t), $$(liftE c)) ||]) rhs) $$(liftMaybe liftE def) ||] Cond p t e -> [|| Cond $$(liftE p) $$(liftE t) $$(liftE e) ||] While p f x -> [|| While $$(liftF p) $$(liftF f) $$(liftE x) ||] - PrimConst t -> [|| PrimConst $$(liftPrimConst t) ||] PrimApp f x -> [|| PrimApp $$(liftPrimFun f) $$(liftE x) ||] Index a ix -> [|| Index $$(liftArrayVar a) $$(liftE ix) ||] LinearIndex a ix -> [|| LinearIndex $$(liftArrayVar a) $$(liftE ix) ||] Shape a -> [|| Shape $$(liftArrayVar a) ||] ShapeSize shr ix -> [|| ShapeSize $$(liftShapeR shr) $$(liftE ix) ||] - Coerce t1 t2 e -> [|| Coerce $$(liftScalarType t1) $$(liftScalarType t2) $$(liftE e) ||] + Bitcast t1 t2 e -> [|| Bitcast $$(liftScalarType t1) $$(liftScalarType t2) $$(liftE e) ||] liftELeftHandSide :: ELeftHandSide t env env' -> CodeQ (ELeftHandSide t env env') liftELeftHandSide = liftLeftHandSide liftScalarType @@ -1319,80 +1452,90 @@ liftBoundary ArrayR (Array sh e) -> Boundary aenv (Array sh e) -> CodeQ (Boundary aenv (Array sh e)) -liftBoundary _ Clamp = [|| Clamp ||] -liftBoundary _ Mirror = [|| Mirror ||] -liftBoundary _ Wrap = [|| Wrap ||] -liftBoundary (ArrayR _ tp) (Constant v) = [|| Constant $$(liftElt tp v) ||] -liftBoundary _ (Function f) = [|| Function $$(liftOpenFun f) ||] - -liftPrimConst :: PrimConst c -> CodeQ (PrimConst c) -liftPrimConst (PrimMinBound t) = [|| PrimMinBound $$(liftBoundedType t) ||] -liftPrimConst (PrimMaxBound t) = [|| PrimMaxBound $$(liftBoundedType t) ||] -liftPrimConst (PrimPi t) = [|| PrimPi $$(liftFloatingType t) ||] +liftBoundary (ArrayR _ tR) = \case + Clamp -> [|| Clamp ||] + Mirror -> [|| Mirror ||] + Wrap -> [|| Wrap ||] + Constant v -> [|| Constant $$(liftElt tR v) ||] + Function f -> [|| Function $$(liftOpenFun f) ||] liftPrimFun :: PrimFun f -> CodeQ (PrimFun f) -liftPrimFun (PrimAdd t) = [|| PrimAdd $$(liftNumType t) ||] -liftPrimFun (PrimSub t) = [|| PrimSub $$(liftNumType t) ||] -liftPrimFun (PrimMul t) = [|| PrimMul $$(liftNumType t) ||] -liftPrimFun (PrimNeg t) = [|| PrimNeg $$(liftNumType t) ||] -liftPrimFun (PrimAbs t) = [|| PrimAbs $$(liftNumType t) ||] -liftPrimFun (PrimSig t) = [|| PrimSig $$(liftNumType t) ||] -liftPrimFun (PrimQuot t) = [|| PrimQuot $$(liftIntegralType t) ||] -liftPrimFun (PrimRem t) = [|| PrimRem $$(liftIntegralType t) ||] -liftPrimFun (PrimQuotRem t) = [|| PrimQuotRem $$(liftIntegralType t) ||] -liftPrimFun (PrimIDiv t) = [|| PrimIDiv $$(liftIntegralType t) ||] -liftPrimFun (PrimMod t) = [|| PrimMod $$(liftIntegralType t) ||] -liftPrimFun (PrimDivMod t) = [|| PrimDivMod $$(liftIntegralType t) ||] -liftPrimFun (PrimBAnd t) = [|| PrimBAnd $$(liftIntegralType t) ||] -liftPrimFun (PrimBOr t) = [|| PrimBOr $$(liftIntegralType t) ||] -liftPrimFun (PrimBXor t) = [|| PrimBXor $$(liftIntegralType t) ||] -liftPrimFun (PrimBNot t) = [|| PrimBNot $$(liftIntegralType t) ||] -liftPrimFun (PrimBShiftL t) = [|| PrimBShiftL $$(liftIntegralType t) ||] -liftPrimFun (PrimBShiftR t) = [|| PrimBShiftR $$(liftIntegralType t) ||] -liftPrimFun (PrimBRotateL t) = [|| PrimBRotateL $$(liftIntegralType t) ||] -liftPrimFun (PrimBRotateR t) = [|| PrimBRotateR $$(liftIntegralType t) ||] -liftPrimFun (PrimPopCount t) = [|| PrimPopCount $$(liftIntegralType t) ||] -liftPrimFun (PrimCountLeadingZeros t) = [|| PrimCountLeadingZeros $$(liftIntegralType t) ||] -liftPrimFun (PrimCountTrailingZeros t) = [|| PrimCountTrailingZeros $$(liftIntegralType t) ||] -liftPrimFun (PrimFDiv t) = [|| PrimFDiv $$(liftFloatingType t) ||] -liftPrimFun (PrimRecip t) = [|| PrimRecip $$(liftFloatingType t) ||] -liftPrimFun (PrimSin t) = [|| PrimSin $$(liftFloatingType t) ||] -liftPrimFun (PrimCos t) = [|| PrimCos $$(liftFloatingType t) ||] -liftPrimFun (PrimTan t) = [|| PrimTan $$(liftFloatingType t) ||] -liftPrimFun (PrimAsin t) = [|| PrimAsin $$(liftFloatingType t) ||] -liftPrimFun (PrimAcos t) = [|| PrimAcos $$(liftFloatingType t) ||] -liftPrimFun (PrimAtan t) = [|| PrimAtan $$(liftFloatingType t) ||] -liftPrimFun (PrimSinh t) = [|| PrimSinh $$(liftFloatingType t) ||] -liftPrimFun (PrimCosh t) = [|| PrimCosh $$(liftFloatingType t) ||] -liftPrimFun (PrimTanh t) = [|| PrimTanh $$(liftFloatingType t) ||] -liftPrimFun (PrimAsinh t) = [|| PrimAsinh $$(liftFloatingType t) ||] -liftPrimFun (PrimAcosh t) = [|| PrimAcosh $$(liftFloatingType t) ||] -liftPrimFun (PrimAtanh t) = [|| PrimAtanh $$(liftFloatingType t) ||] -liftPrimFun (PrimExpFloating t) = [|| PrimExpFloating $$(liftFloatingType t) ||] -liftPrimFun (PrimSqrt t) = [|| PrimSqrt $$(liftFloatingType t) ||] -liftPrimFun (PrimLog t) = [|| PrimLog $$(liftFloatingType t) ||] -liftPrimFun (PrimFPow t) = [|| PrimFPow $$(liftFloatingType t) ||] -liftPrimFun (PrimLogBase t) = [|| PrimLogBase $$(liftFloatingType t) ||] -liftPrimFun (PrimTruncate ta tb) = [|| PrimTruncate $$(liftFloatingType ta) $$(liftIntegralType tb) ||] -liftPrimFun (PrimRound ta tb) = [|| PrimRound $$(liftFloatingType ta) $$(liftIntegralType tb) ||] -liftPrimFun (PrimFloor ta tb) = [|| PrimFloor $$(liftFloatingType ta) $$(liftIntegralType tb) ||] -liftPrimFun (PrimCeiling ta tb) = [|| PrimCeiling $$(liftFloatingType ta) $$(liftIntegralType tb) ||] -liftPrimFun (PrimIsNaN t) = [|| PrimIsNaN $$(liftFloatingType t) ||] -liftPrimFun (PrimIsInfinite t) = [|| PrimIsInfinite $$(liftFloatingType t) ||] -liftPrimFun (PrimAtan2 t) = [|| PrimAtan2 $$(liftFloatingType t) ||] -liftPrimFun (PrimLt t) = [|| PrimLt $$(liftSingleType t) ||] -liftPrimFun (PrimGt t) = [|| PrimGt $$(liftSingleType t) ||] -liftPrimFun (PrimLtEq t) = [|| PrimLtEq $$(liftSingleType t) ||] -liftPrimFun (PrimGtEq t) = [|| PrimGtEq $$(liftSingleType t) ||] -liftPrimFun (PrimEq t) = [|| PrimEq $$(liftSingleType t) ||] -liftPrimFun (PrimNEq t) = [|| PrimNEq $$(liftSingleType t) ||] -liftPrimFun (PrimMax t) = [|| PrimMax $$(liftSingleType t) ||] -liftPrimFun (PrimMin t) = [|| PrimMin $$(liftSingleType t) ||] -liftPrimFun PrimLAnd = [|| PrimLAnd ||] -liftPrimFun PrimLOr = [|| PrimLOr ||] -liftPrimFun PrimLNot = [|| PrimLNot ||] -liftPrimFun (PrimFromIntegral ta tb) = [|| PrimFromIntegral $$(liftIntegralType ta) $$(liftNumType tb) ||] -liftPrimFun (PrimToFloating ta tb) = [|| PrimToFloating $$(liftNumType ta) $$(liftFloatingType tb) ||] +liftPrimFun = \case + PrimAdd t -> [|| PrimAdd $$(liftNumType t) ||] + PrimSub t -> [|| PrimSub $$(liftNumType t) ||] + PrimMul t -> [|| PrimMul $$(liftNumType t) ||] + PrimNeg t -> [|| PrimNeg $$(liftNumType t) ||] + PrimAbs t -> [|| PrimAbs $$(liftNumType t) ||] + PrimSig t -> [|| PrimSig $$(liftNumType t) ||] + PrimVAdd t -> [|| PrimVAdd $$(liftNumType t) ||] + PrimVMul t -> [|| PrimVMul $$(liftNumType t) ||] + PrimQuot t -> [|| PrimQuot $$(liftIntegralType t) ||] + PrimRem t -> [|| PrimRem $$(liftIntegralType t) ||] + PrimQuotRem t -> [|| PrimQuotRem $$(liftIntegralType t) ||] + PrimIDiv t -> [|| PrimIDiv $$(liftIntegralType t) ||] + PrimMod t -> [|| PrimMod $$(liftIntegralType t) ||] + PrimDivMod t -> [|| PrimDivMod $$(liftIntegralType t) ||] + PrimBAnd t -> [|| PrimBAnd $$(liftIntegralType t) ||] + PrimBOr t -> [|| PrimBOr $$(liftIntegralType t) ||] + PrimBXor t -> [|| PrimBXor $$(liftIntegralType t) ||] + PrimBNot t -> [|| PrimBNot $$(liftIntegralType t) ||] + PrimBShiftL t -> [|| PrimBShiftL $$(liftIntegralType t) ||] + PrimBShiftR t -> [|| PrimBShiftR $$(liftIntegralType t) ||] + PrimBRotateL t -> [|| PrimBRotateL $$(liftIntegralType t) ||] + PrimBRotateR t -> [|| PrimBRotateR $$(liftIntegralType t) ||] + PrimPopCount t -> [|| PrimPopCount $$(liftIntegralType t) ||] + PrimCountLeadingZeros t -> [|| PrimCountLeadingZeros $$(liftIntegralType t) ||] + PrimCountTrailingZeros t -> [|| PrimCountTrailingZeros $$(liftIntegralType t) ||] + PrimBReverse t -> [|| PrimBReverse $$(liftIntegralType t) ||] + PrimBSwap t -> [|| PrimBSwap $$(liftIntegralType t) ||] + PrimVBAnd t -> [|| PrimVBAnd $$(liftIntegralType t) ||] + PrimVBOr t -> [|| PrimVBOr $$(liftIntegralType t) ||] + PrimVBXor t -> [|| PrimVBXor $$(liftIntegralType t) ||] + PrimFDiv t -> [|| PrimFDiv $$(liftFloatingType t) ||] + PrimRecip t -> [|| PrimRecip $$(liftFloatingType t) ||] + PrimSin t -> [|| PrimSin $$(liftFloatingType t) ||] + PrimCos t -> [|| PrimCos $$(liftFloatingType t) ||] + PrimTan t -> [|| PrimTan $$(liftFloatingType t) ||] + PrimAsin t -> [|| PrimAsin $$(liftFloatingType t) ||] + PrimAcos t -> [|| PrimAcos $$(liftFloatingType t) ||] + PrimAtan t -> [|| PrimAtan $$(liftFloatingType t) ||] + PrimSinh t -> [|| PrimSinh $$(liftFloatingType t) ||] + PrimCosh t -> [|| PrimCosh $$(liftFloatingType t) ||] + PrimTanh t -> [|| PrimTanh $$(liftFloatingType t) ||] + PrimAsinh t -> [|| PrimAsinh $$(liftFloatingType t) ||] + PrimAcosh t -> [|| PrimAcosh $$(liftFloatingType t) ||] + PrimAtanh t -> [|| PrimAtanh $$(liftFloatingType t) ||] + PrimExpFloating t -> [|| PrimExpFloating $$(liftFloatingType t) ||] + PrimSqrt t -> [|| PrimSqrt $$(liftFloatingType t) ||] + PrimLog t -> [|| PrimLog $$(liftFloatingType t) ||] + PrimFPow t -> [|| PrimFPow $$(liftFloatingType t) ||] + PrimLogBase t -> [|| PrimLogBase $$(liftFloatingType t) ||] + PrimTruncate ta tb -> [|| PrimTruncate $$(liftFloatingType ta) $$(liftIntegralType tb) ||] + PrimRound ta tb -> [|| PrimRound $$(liftFloatingType ta) $$(liftIntegralType tb) ||] + PrimFloor ta tb -> [|| PrimFloor $$(liftFloatingType ta) $$(liftIntegralType tb) ||] + PrimCeiling ta tb -> [|| PrimCeiling $$(liftFloatingType ta) $$(liftIntegralType tb) ||] + PrimIsNaN t -> [|| PrimIsNaN $$(liftFloatingType t) ||] + PrimIsInfinite t -> [|| PrimIsInfinite $$(liftFloatingType t) ||] + PrimAtan2 t -> [|| PrimAtan2 $$(liftFloatingType t) ||] + PrimLt t -> [|| PrimLt $$(liftScalarType t) ||] + PrimGt t -> [|| PrimGt $$(liftScalarType t) ||] + PrimLtEq t -> [|| PrimLtEq $$(liftScalarType t) ||] + PrimGtEq t -> [|| PrimGtEq $$(liftScalarType t) ||] + PrimEq t -> [|| PrimEq $$(liftScalarType t) ||] + PrimNEq t -> [|| PrimNEq $$(liftScalarType t) ||] + PrimMin t -> [|| PrimMin $$(liftScalarType t) ||] + PrimMax t -> [|| PrimMax $$(liftScalarType t) ||] + PrimVMin t -> [|| PrimVMin $$(liftScalarType t) ||] + PrimVMax t -> [|| PrimVMax $$(liftScalarType t) ||] + PrimLAnd t -> [|| PrimLAnd $$(liftBitType t) ||] + PrimLOr t -> [|| PrimLOr $$(liftBitType t) ||] + PrimLNot t -> [|| PrimLNot $$(liftBitType t) ||] + PrimVLAnd t -> [|| PrimVLAnd $$(liftBitType t) ||] + PrimVLOr t -> [|| PrimVLOr $$(liftBitType t) ||] + PrimFromIntegral ta tb -> [|| PrimFromIntegral $$(liftIntegralType ta) $$(liftNumType tb) ||] + PrimToFloating ta tb -> [|| PrimToFloating $$(liftNumType ta) $$(liftFloatingType tb) ||] + PrimToBool ta tb -> [|| PrimToBool $$(liftIntegralType ta) $$(liftBitType tb) ||] + PrimFromBool ta tb -> [|| PrimFromBool $$(liftBitType ta) $$(liftIntegralType tb) ||] formatDirection :: Format r (Direction -> r) @@ -1406,6 +1549,7 @@ formatPreAccOp = later $ \case Avar (Var _ ix) -> bformat ("Avar a" % int) (idxToInt ix) Use aR a -> bformat ("Use " % string) (showArrayShort 5 (showsElt (arrayRtype aR)) aR a) Atrace{} -> "Atrace" + Acoerce{} -> "Acoerce" Apply{} -> "Apply" Aforeign{} -> "Aforeign" Acond{} -> "Acond" @@ -1438,8 +1582,10 @@ formatExpOp = later $ \case Foreign{} -> "Foreign" Pair{} -> "Pair" Nil{} -> "Nil" - VecPack{} -> "VecPack" - VecUnpack{} -> "VecUnpack" + Insert{} -> "Insert" + Extract{} -> "Extract" + Shuffle{} -> "Shuffle" + Select{} -> "Select" IndexSlice{} -> "IndexSlice" IndexFull{} -> "IndexFull" ToIndex{} -> "ToIndex" @@ -1447,11 +1593,10 @@ formatExpOp = later $ \case Case{} -> "Case" Cond{} -> "Cond" While{} -> "While" - PrimConst{} -> "PrimConst" PrimApp{} -> "PrimApp" Index{} -> "Index" LinearIndex{} -> "LinearIndex" Shape{} -> "Shape" ShapeSize{} -> "ShapeSize" - Coerce{} -> "Coerce" + Bitcast{} -> "Coerce" diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index 7e0fe88c9..51e959da3 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -44,11 +44,13 @@ import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Shape import Data.Array.Accelerate.Representation.Slice import Data.Array.Accelerate.Representation.Stencil +import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Type import Data.Primitive.Vec import Crypto.Hash.XKCP +import Foreign.C.Types import Data.ByteString.Builder import Data.ByteString.Builder.Extra import Data.ByteString.Short.Internal ( ShortByteString(..) ) @@ -58,6 +60,8 @@ import System.Mem.StableName ( hashStable import Prelude hiding ( exp ) import qualified Data.Hashable as Hashable +import GHC.TypeLits + -- Hashing -- ------- @@ -170,6 +174,7 @@ encodePreOpenAcc options encodeAcc pacc = Apair a1 a2 -> intHost $(hashQ "Apair") <> travA a1 <> travA a2 Anil -> intHost $(hashQ "Anil") Atrace (Message _ _ msg) as bs -> intHost $(hashQ "Atrace") <> intHost (Hashable.hash msg) <> travA as <> travA bs + Acoerce _ bR a -> intHost $(hashQ "Acoerce") <> encodeTypeR bR <> travA a Apply _ f a -> intHost $(hashQ "Apply") <> travAF f <> travA a Aforeign _ _ f a -> intHost $(hashQ "Aforeign") <> travAF f <> travA a Use repr a -> intHost $(hashQ "Use") <> encodeArrayType repr <> deep (encodeArray a) @@ -263,7 +268,7 @@ encodeArraysType :: ArraysR arrs -> Builder encodeArraysType = encodeTupR encodeArrayType encodeShapeR :: ShapeR sh -> Builder -encodeShapeR = intHost . rank +encodeShapeR = int64Host . rank encodePreOpenAfun :: forall acc aenv f. @@ -288,13 +293,13 @@ encodeBoundary encodeBoundary _ Wrap = intHost $(hashQ "Wrap") encodeBoundary _ Clamp = intHost $(hashQ "Clamp") encodeBoundary _ Mirror = intHost $(hashQ "Mirror") -encodeBoundary tp (Constant c) = intHost $(hashQ "Constant") <> encodeConst tp c +encodeBoundary tR (Constant c) = intHost $(hashQ "Constant") <> encodeConst tR c encodeBoundary _ (Function f) = intHost $(hashQ "Function") <> encodeOpenFun f encodeSliceIndex :: SliceIndex slix sl co sh -> Builder -encodeSliceIndex SliceNil = intHost $(hashQ "SliceNil") -encodeSliceIndex (SliceAll r) = intHost $(hashQ "SliceAll") <> encodeSliceIndex r -encodeSliceIndex (SliceFixed r) = intHost $(hashQ "sliceFixed") <> encodeSliceIndex r +encodeSliceIndex SliceNil = intHost $(hashQ "SliceNil") +encodeSliceIndex (SliceAll r) = intHost $(hashQ "SliceAll") <> encodeSliceIndex r +encodeSliceIndex (SliceFixed r) = intHost $(hashQ "sliceFixed") <> encodeSliceIndex r -- Scalar expressions @@ -318,25 +323,26 @@ encodeOpenExp exp = Evar (Var tp ix) -> intHost $(hashQ "Evar") <> encodeScalarType tp <> encodeIdx ix Nil -> intHost $(hashQ "Nil") Pair e1 e2 -> intHost $(hashQ "Pair") <> travE e1 <> travE e2 - VecPack _ e -> intHost $(hashQ "VecPack") <> travE e - VecUnpack _ e -> intHost $(hashQ "VecUnpack") <> travE e + Extract vR iR v i -> intHost $(hashQ "Extract") <> encodeScalarType vR <> encodeSingleIntegralType iR <> travE v <> travE i + Insert vR iR v i x -> intHost $(hashQ "Insert") <> encodeScalarType vR <> encodeSingleIntegralType iR <> travE v <> travE i <> travE x + Shuffle eR iR x y i -> intHost $(hashQ "Shuffle") <> encodeScalarType eR <> encodeSingleIntegralType iR <> travE x <> travE y <> travE i + Select eR m x y -> intHost $(hashQ "Select") <> encodeScalarType eR <> travE m <>travE x <> travE y Const tp c -> intHost $(hashQ "Const") <> encodeScalarConst tp c Undef tp -> intHost $(hashQ "Undef") <> encodeScalarType tp IndexSlice spec ix sh -> intHost $(hashQ "IndexSlice") <> travE ix <> travE sh <> encodeSliceIndex spec IndexFull spec ix sl -> intHost $(hashQ "IndexFull") <> travE ix <> travE sl <> encodeSliceIndex spec ToIndex _ sh i -> intHost $(hashQ "ToIndex") <> travE sh <> travE i FromIndex _ sh i -> intHost $(hashQ "FromIndex") <> travE sh <> travE i - Case e rhs def -> intHost $(hashQ "Case") <> travE e <> mconcat [ word8 t <> travE c | (t,c) <- rhs ] <> encodeMaybe travE def + Case eR e rhs def -> intHost $(hashQ "Case") <> encodeTagType eR <> travE e <> mconcat [ encodeTagConst eR t <> travE c | (t,c) <- rhs ] <> encodeMaybe travE def Cond c t e -> intHost $(hashQ "Cond") <> travE c <> travE t <> travE e While p f x -> intHost $(hashQ "While") <> travF p <> travF f <> travE x PrimApp f x -> intHost $(hashQ "PrimApp") <> encodePrimFun f <> travE x - PrimConst c -> intHost $(hashQ "PrimConst") <> encodePrimConst c Index a ix -> intHost $(hashQ "Index") <> encodeArrayVar a <> travE ix LinearIndex a ix -> intHost $(hashQ "LinearIndex") <> encodeArrayVar a <> travE ix Shape a -> intHost $(hashQ "Shape") <> encodeArrayVar a ShapeSize _ sh -> intHost $(hashQ "ShapeSize") <> travE sh Foreign _ _ f e -> intHost $(hashQ "Foreign") <> travF f <> travE e - Coerce _ tp e -> intHost $(hashQ "Coerce") <> encodeScalarType tp <> travE e + Bitcast _ tp e -> intHost $(hashQ "Bitcast") <> encodeScalarType tp <> travE e encodeArrayVar :: ArrayVar aenv a -> Builder encodeArrayVar (Var repr v) = encodeArrayType repr <> encodeIdx v @@ -355,40 +361,49 @@ encodeConst (TupRsingle t) c = encodeScalarConst t c encodeConst (TupRpair ta tb) (a,b) = intHost $(hashQ "pair") <> encodeConst ta a <> encodeConst tb b encodeScalarConst :: ScalarType t -> t -> Builder -encodeScalarConst (SingleScalarType t) = encodeSingleConst t -encodeScalarConst (VectorScalarType t) = encodeVectorConst t - -encodeSingleConst :: SingleType t -> t -> Builder -encodeSingleConst (NumSingleType t) = encodeNumConst t +encodeScalarConst (NumScalarType t) = encodeNumConst t +encodeScalarConst (BitScalarType t) = encodeBitConst t -encodeVectorConst :: VectorType (Vec n t) -> Vec n t -> Builder -encodeVectorConst (VectorType n t) (Vec ba#) = intHost $(hashQ "Vec") <> intHost n <> encodeSingleType t <> shortByteString (SBS ba#) +encodeBitConst :: BitType t -> t -> Builder +encodeBitConst TypeBit (Bit False) = intHost $(hashQ "Bit") <> int8 0 +encodeBitConst TypeBit (Bit True) = intHost $(hashQ "Bit") <> int8 1 +encodeBitConst (TypeMask n) (Vec ba#) = intHost $(hashQ "BitMask") <> int8 (fromIntegral (natVal' n)) <> shortByteString (SBS ba#) encodeNumConst :: NumType t -> t -> Builder encodeNumConst (IntegralNumType t) = encodeIntegralConst t encodeNumConst (FloatingNumType t) = encodeFloatingConst t encodeIntegralConst :: IntegralType t -> t -> Builder -encodeIntegralConst TypeInt{} x = intHost $(hashQ "Int") <> intHost x -encodeIntegralConst TypeInt8{} x = intHost $(hashQ "Int8") <> int8 x -encodeIntegralConst TypeInt16{} x = intHost $(hashQ "Int16") <> int16Host x -encodeIntegralConst TypeInt32{} x = intHost $(hashQ "Int32") <> int32Host x -encodeIntegralConst TypeInt64{} x = intHost $(hashQ "Int64") <> int64Host x -encodeIntegralConst TypeWord{} x = intHost $(hashQ "Word") <> wordHost x -encodeIntegralConst TypeWord8{} x = intHost $(hashQ "Word8") <> word8 x -encodeIntegralConst TypeWord16{} x = intHost $(hashQ "Word16") <> word16Host x -encodeIntegralConst TypeWord32{} x = intHost $(hashQ "Word32") <> word32Host x -encodeIntegralConst TypeWord64{} x = intHost $(hashQ "Word64") <> word64Host x +encodeIntegralConst (SingleIntegralType t) x = encodeSingleIntegralConst t x +encodeIntegralConst (VectorIntegralType n t) (Vec ba#) = intHost $(hashQ "Vec") <> int8 (fromIntegral (natVal' n)) <> encodeSingleIntegralType t <> shortByteString (SBS ba#) + +encodeSingleIntegralConst :: SingleIntegralType t -> t -> Builder +encodeSingleIntegralConst TypeInt8 x = intHost $(hashQ "Int8") <> int8 x +encodeSingleIntegralConst TypeInt16 x = intHost $(hashQ "Int16") <> int16Host x +encodeSingleIntegralConst TypeInt32 x = intHost $(hashQ "Int32") <> int32Host x +encodeSingleIntegralConst TypeInt64 x = intHost $(hashQ "Int64") <> int64Host x +encodeSingleIntegralConst TypeInt128 (Int128 x y) = intHost $(hashQ "Int128") <> word64Host x <> word64Host y +encodeSingleIntegralConst TypeWord8 x = intHost $(hashQ "Word8") <> word8 x +encodeSingleIntegralConst TypeWord16 x = intHost $(hashQ "Word16") <> word16Host x +encodeSingleIntegralConst TypeWord32 x = intHost $(hashQ "Word32") <> word32Host x +encodeSingleIntegralConst TypeWord64 x = intHost $(hashQ "Word64") <> word64Host x +encodeSingleIntegralConst TypeWord128 (Word128 x y) = intHost $(hashQ "Word128") <> word64Host x <> word64Host y encodeFloatingConst :: FloatingType t -> t -> Builder -encodeFloatingConst TypeHalf{} (Half (CUShort x)) = intHost $(hashQ "Half") <> word16Host x -encodeFloatingConst TypeFloat{} x = intHost $(hashQ "Float") <> floatHost x -encodeFloatingConst TypeDouble{} x = intHost $(hashQ "Double") <> doubleHost x +encodeFloatingConst (SingleFloatingType t) x = encodeSingleFloatingConst t x +encodeFloatingConst (VectorFloatingType n t) (Vec ba#) = intHost $(hashQ "Vec") <> int8 (fromIntegral (natVal' n)) <> encodeSingleFloatingType t <> shortByteString (SBS ba#) -encodePrimConst :: PrimConst c -> Builder -encodePrimConst (PrimMinBound t) = intHost $(hashQ "PrimMinBound") <> encodeBoundedType t -encodePrimConst (PrimMaxBound t) = intHost $(hashQ "PrimMaxBound") <> encodeBoundedType t -encodePrimConst (PrimPi t) = intHost $(hashQ "PrimPi") <> encodeFloatingType t +encodeSingleFloatingConst :: SingleFloatingType t -> t -> Builder +encodeSingleFloatingConst TypeFloat16 (Half (CUShort x)) = intHost $(hashQ "Half") <> word16Host x +encodeSingleFloatingConst TypeFloat32 x = intHost $(hashQ "Float") <> floatHost x +encodeSingleFloatingConst TypeFloat64 x = intHost $(hashQ "Double") <> doubleHost x +encodeSingleFloatingConst TypeFloat128 (Float128 x y) = intHost $(hashQ "Float128") <> word64Host x <> word64Host y + +encodeTagConst :: TagType t -> t -> Builder +encodeTagConst TagBit (Bit False) = intHost $(hashQ "Bit") <> int8 0 +encodeTagConst TagBit (Bit True) = intHost $(hashQ "Bit") <> int8 1 +encodeTagConst TagWord8 x = intHost $(hashQ "Tag8") <> word8 x +encodeTagConst TagWord16 x = intHost $(hashQ "Tag16") <> word16Host x encodePrimFun :: PrimFun f -> Builder encodePrimFun (PrimAdd a) = intHost $(hashQ "PrimAdd") <> encodeNumType a @@ -397,6 +412,8 @@ encodePrimFun (PrimMul a) = intHost $(hashQ "PrimMul") encodePrimFun (PrimNeg a) = intHost $(hashQ "PrimNeg") <> encodeNumType a encodePrimFun (PrimAbs a) = intHost $(hashQ "PrimAbs") <> encodeNumType a encodePrimFun (PrimSig a) = intHost $(hashQ "PrimSig") <> encodeNumType a +encodePrimFun (PrimVAdd a) = intHost $(hashQ "PrimVAdd") <> encodeNumType a +encodePrimFun (PrimVMul a) = intHost $(hashQ "PrimVMul") <> encodeNumType a encodePrimFun (PrimQuot a) = intHost $(hashQ "PrimQuot") <> encodeIntegralType a encodePrimFun (PrimRem a) = intHost $(hashQ "PrimRem") <> encodeIntegralType a encodePrimFun (PrimQuotRem a) = intHost $(hashQ "PrimQuotRem") <> encodeIntegralType a @@ -414,6 +431,11 @@ encodePrimFun (PrimBRotateR a) = intHost $(hashQ "PrimBRotateR") encodePrimFun (PrimPopCount a) = intHost $(hashQ "PrimPopCount") <> encodeIntegralType a encodePrimFun (PrimCountLeadingZeros a) = intHost $(hashQ "PrimCountLeadingZeros") <> encodeIntegralType a encodePrimFun (PrimCountTrailingZeros a) = intHost $(hashQ "PrimCountTrailingZeros") <> encodeIntegralType a +encodePrimFun (PrimBReverse a) = intHost $(hashQ "PrimBReverse") <> encodeIntegralType a +encodePrimFun (PrimBSwap a) = intHost $(hashQ "PrimBSwap") <> encodeIntegralType a +encodePrimFun (PrimVBAnd a) = intHost $(hashQ "PrimVBAnd") <> encodeIntegralType a +encodePrimFun (PrimVBOr a) = intHost $(hashQ "PrimVBOr") <> encodeIntegralType a +encodePrimFun (PrimVBXor a) = intHost $(hashQ "PrimVBXor") <> encodeIntegralType a encodePrimFun (PrimFDiv a) = intHost $(hashQ "PrimFDiv") <> encodeFloatingType a encodePrimFun (PrimRecip a) = intHost $(hashQ "PrimRecip") <> encodeFloatingType a encodePrimFun (PrimSin a) = intHost $(hashQ "PrimSin") <> encodeFloatingType a @@ -440,20 +462,25 @@ encodePrimFun (PrimFloor a b) = intHost $(hashQ "PrimFloor") encodePrimFun (PrimCeiling a b) = intHost $(hashQ "PrimCeiling") <> encodeFloatingType a <> encodeIntegralType b encodePrimFun (PrimIsNaN a) = intHost $(hashQ "PrimIsNaN") <> encodeFloatingType a encodePrimFun (PrimIsInfinite a) = intHost $(hashQ "PrimIsInfinite") <> encodeFloatingType a -encodePrimFun (PrimLt a) = intHost $(hashQ "PrimLt") <> encodeSingleType a -encodePrimFun (PrimGt a) = intHost $(hashQ "PrimGt") <> encodeSingleType a -encodePrimFun (PrimLtEq a) = intHost $(hashQ "PrimLtEq") <> encodeSingleType a -encodePrimFun (PrimGtEq a) = intHost $(hashQ "PrimGtEq") <> encodeSingleType a -encodePrimFun (PrimEq a) = intHost $(hashQ "PrimEq") <> encodeSingleType a -encodePrimFun (PrimNEq a) = intHost $(hashQ "PrimNEq") <> encodeSingleType a -encodePrimFun (PrimMax a) = intHost $(hashQ "PrimMax") <> encodeSingleType a -encodePrimFun (PrimMin a) = intHost $(hashQ "PrimMin") <> encodeSingleType a +encodePrimFun (PrimLt a) = intHost $(hashQ "PrimLt") <> encodeScalarType a +encodePrimFun (PrimGt a) = intHost $(hashQ "PrimGt") <> encodeScalarType a +encodePrimFun (PrimLtEq a) = intHost $(hashQ "PrimLtEq") <> encodeScalarType a +encodePrimFun (PrimGtEq a) = intHost $(hashQ "PrimGtEq") <> encodeScalarType a +encodePrimFun (PrimEq a) = intHost $(hashQ "PrimEq") <> encodeScalarType a +encodePrimFun (PrimNEq a) = intHost $(hashQ "PrimNEq") <> encodeScalarType a +encodePrimFun (PrimMin a) = intHost $(hashQ "PrimMin") <> encodeScalarType a +encodePrimFun (PrimMax a) = intHost $(hashQ "PrimMax") <> encodeScalarType a +encodePrimFun (PrimVMin a) = intHost $(hashQ "PrimVMin") <> encodeScalarType a +encodePrimFun (PrimVMax a) = intHost $(hashQ "PrimVMax") <> encodeScalarType a +encodePrimFun (PrimLAnd a) = intHost $(hashQ "PrimLAnd") <> encodeBitType a +encodePrimFun (PrimLOr a) = intHost $(hashQ "PrimLOr") <> encodeBitType a +encodePrimFun (PrimLNot a) = intHost $(hashQ "PrimLNot") <> encodeBitType a +encodePrimFun (PrimVLAnd a) = intHost $(hashQ "PrimVLAnd") <> encodeBitType a +encodePrimFun (PrimVLOr a) = intHost $(hashQ "PrimVLOr") <> encodeBitType a encodePrimFun (PrimFromIntegral a b) = intHost $(hashQ "PrimFromIntegral") <> encodeIntegralType a <> encodeNumType b -encodePrimFun (PrimToFloating a b) = intHost $(hashQ "PrimToFloating") <> encodeNumType a <> encodeFloatingType b -encodePrimFun PrimLAnd = intHost $(hashQ "PrimLAnd") -encodePrimFun PrimLOr = intHost $(hashQ "PrimLOr") -encodePrimFun PrimLNot = intHost $(hashQ "PrimLNot") - +encodePrimFun (PrimToFloating a b) = intHost $(hashQ "PrimToFloating") <> encodeNumType a <> encodeFloatingType b +encodePrimFun (PrimToBool a b) = intHost $(hashQ "PrimToBool") <> encodeIntegralType a <> encodeBitType b +encodePrimFun (PrimFromBool a b) = intHost $(hashQ "PrimFromBool") <> encodeBitType a <> encodeIntegralType b encodeTypeR :: TypeR t -> Builder encodeTypeR TupRunit = intHost $(hashQ "TupRunit") @@ -467,38 +494,39 @@ depthTypeR TupRsingle{} = 1 depthTypeR (TupRpair a b) = depthTypeR a + depthTypeR b encodeScalarType :: ScalarType t -> Builder -encodeScalarType (SingleScalarType t) = intHost $(hashQ "SingleScalarType") <> encodeSingleType t -encodeScalarType (VectorScalarType t) = intHost $(hashQ "VectorScalarType") <> encodeVectorType t +encodeScalarType (NumScalarType t) = encodeNumType t +encodeScalarType (BitScalarType t) = encodeBitType t -encodeSingleType :: SingleType t -> Builder -encodeSingleType (NumSingleType t) = intHost $(hashQ "NumSingleType") <> encodeNumType t - -encodeVectorType :: VectorType (Vec n t) -> Builder -encodeVectorType (VectorType n t) = intHost $(hashQ "VectorType") <> intHost n <> encodeSingleType t - -encodeBoundedType :: BoundedType t -> Builder -encodeBoundedType (IntegralBoundedType t) = intHost $(hashQ "IntegralBoundedType") <> encodeIntegralType t +encodeBitType :: BitType t -> Builder +encodeBitType TypeBit = intHost $(hashQ "Bit") +encodeBitType (TypeMask n) = intHost $(hashQ "BitMask") <> int8 (fromIntegral (natVal' n)) encodeNumType :: NumType t -> Builder encodeNumType (IntegralNumType t) = intHost $(hashQ "IntegralNumType") <> encodeIntegralType t encodeNumType (FloatingNumType t) = intHost $(hashQ "FloatingNumType") <> encodeFloatingType t encodeIntegralType :: IntegralType t -> Builder -encodeIntegralType TypeInt{} = intHost $(hashQ "Int") -encodeIntegralType TypeInt8{} = intHost $(hashQ "Int8") -encodeIntegralType TypeInt16{} = intHost $(hashQ "Int16") -encodeIntegralType TypeInt32{} = intHost $(hashQ "Int32") -encodeIntegralType TypeInt64{} = intHost $(hashQ "Int64") -encodeIntegralType TypeWord{} = intHost $(hashQ "Word") -encodeIntegralType TypeWord8{} = intHost $(hashQ "Word8") -encodeIntegralType TypeWord16{} = intHost $(hashQ "Word16") -encodeIntegralType TypeWord32{} = intHost $(hashQ "Word32") -encodeIntegralType TypeWord64{} = intHost $(hashQ "Word64") +encodeIntegralType (SingleIntegralType t) = encodeSingleIntegralType t +encodeIntegralType (VectorIntegralType n t) = intHost $(hashQ "Vec") <> int8 (fromIntegral (natVal' n)) <> encodeSingleIntegralType t + +encodeSingleIntegralType :: SingleIntegralType t -> Builder +encodeSingleIntegralType (TypeInt n) = intHost $(hashQ "Int") <> intHost n +encodeSingleIntegralType (TypeWord n) = intHost $(hashQ "Word") <> intHost n encodeFloatingType :: FloatingType t -> Builder -encodeFloatingType TypeHalf{} = intHost $(hashQ "Half") -encodeFloatingType TypeFloat{} = intHost $(hashQ "Float") -encodeFloatingType TypeDouble{} = intHost $(hashQ "Double") +encodeFloatingType (SingleFloatingType t) = encodeSingleFloatingType t +encodeFloatingType (VectorFloatingType n t) = intHost $(hashQ "Vec") <> int8 (fromIntegral (natVal' n)) <> encodeSingleFloatingType t + +encodeSingleFloatingType :: SingleFloatingType t -> Builder +encodeSingleFloatingType TypeFloat16 = intHost $(hashQ "Half") +encodeSingleFloatingType TypeFloat32 = intHost $(hashQ "Float") +encodeSingleFloatingType TypeFloat64 = intHost $(hashQ "Double") +encodeSingleFloatingType TypeFloat128 = intHost $(hashQ "Float128") + +encodeTagType :: TagType t -> Builder +encodeTagType TagBit = intHost $(hashQ "TagBit") +encodeTagType TagWord8 = intHost $(hashQ "TagWord8") +encodeTagType TagWord16 = intHost $(hashQ "TagWord16") encodeMaybe :: (a -> Builder) -> Maybe a -> Builder encodeMaybe _ Nothing = intHost $(hashQ "Nothing") diff --git a/src/Data/Array/Accelerate/Analysis/Match.hs b/src/Data/Array/Accelerate/Analysis/Match.hs index 98a0818a4..11e1dc7ab 100644 --- a/src/Data/Array/Accelerate/Analysis/Match.hs +++ b/src/Data/Array/Accelerate/Analysis/Match.hs @@ -1,6 +1,8 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE CPP #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-} {-# LANGUAGE PatternGuards #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -30,8 +32,14 @@ module Data.Array.Accelerate.Analysis.Match ( -- auxiliary matchIdx, matchVar, matchVars, matchArrayR, matchArraysR, matchTypeR, matchShapeR, - matchShapeType, matchIntegralType, matchFloatingType, matchNumType, matchScalarType, - matchLeftHandSide, matchALeftHandSide, matchELeftHandSide, matchSingleType, matchTupR + matchShapeType, + matchIntegralType, matchSingleIntegralType, + matchFloatingType, matchSingleFloatingType, + matchNumType, + matchBitType, + matchTagType, + matchScalarType, + matchLeftHandSide, matchALeftHandSide, matchELeftHandSide, matchTupR, ) where @@ -40,21 +48,24 @@ import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.AST.LeftHandSide import Data.Array.Accelerate.AST.Var import Data.Array.Accelerate.Analysis.Hash -import Data.Array.Accelerate.Representation.Array -import Data.Array.Accelerate.Representation.Shape -import Data.Array.Accelerate.Representation.Slice +import Data.Array.Accelerate.Interpreter.Arithmetic +import Data.Array.Accelerate.Representation.Array ( Array(..), ArraysR, ArrayR(..) ) +import Data.Array.Accelerate.Representation.Shape ( ShapeR(..) ) +import Data.Array.Accelerate.Representation.Slice ( SliceIndex(..) ) import Data.Array.Accelerate.Representation.Stencil +import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Type -import Data.Primitive.Vec -import qualified Data.Array.Accelerate.Sugar.Shape as Sugar +import qualified Data.Array.Accelerate.Sugar.Shape as Sugar import Data.Maybe import Data.Typeable -import Unsafe.Coerce ( unsafeCoerce ) -import System.IO.Unsafe ( unsafePerformIO ) +import Unsafe.Coerce ( unsafeCoerce ) +import System.IO.Unsafe ( unsafePerformIO ) import System.Mem.StableName -import Prelude hiding ( exp ) +import Prelude hiding ( exp ) + +import GHC.TypeLits.Extra -- The type of matching array computations @@ -455,24 +466,42 @@ matchOpenExp (Foreign _ ff1 f1 e1) (Foreign _ ff2 f2 e2) , Just Refl <- matchOpenFun f1 f2 = Just Refl -matchOpenExp (Const t1 c1) (Const t2 c2) - | Just Refl <- matchScalarType t1 t2 - , matchConst (TupRsingle t1) c1 c2 +matchOpenExp (Pair a1 b1) (Pair a2 b2) + | Just Refl <- matchOpenExp a1 a2 + , Just Refl <- matchOpenExp b1 b2 = Just Refl -matchOpenExp (Undef t1) (Undef t2) = matchScalarType t1 t2 +matchOpenExp Nil Nil + = Just Refl -matchOpenExp (Coerce _ t1 e1) (Coerce _ t2 e2) - | Just Refl <- matchScalarType t1 t2 - , Just Refl <- matchOpenExp e1 e2 +matchOpenExp (Extract vR1 iR1 v1 i1) (Extract vR2 iR2 v2 i2) + | Just Refl <- matchScalarType vR1 vR2 + , Just Refl <- matchSingleIntegralType iR1 iR2 + , Just Refl <- matchOpenExp v1 v2 + , Just Refl <- matchOpenExp i1 i2 = Just Refl -matchOpenExp (Pair a1 b1) (Pair a2 b2) - | Just Refl <- matchOpenExp a1 a2 - , Just Refl <- matchOpenExp b1 b2 +matchOpenExp (Insert vR1 iR1 v1 i1 x1) (Insert vR2 iR2 v2 i2 x2) + | Just Refl <- matchScalarType vR1 vR2 + , Just Refl <- matchSingleIntegralType iR1 iR2 + , Just Refl <- matchOpenExp v1 v2 + , Just Refl <- matchOpenExp i1 i2 + , Just Refl <- matchOpenExp x1 x2 = Just Refl -matchOpenExp Nil Nil +matchOpenExp (Shuffle eR1 iR1 x1 y1 i1) (Shuffle eR2 iR2 x2 y2 i2) + | Just Refl <- matchScalarType eR1 eR2 + , Just Refl <- matchSingleIntegralType iR1 iR2 + , Just Refl <- matchOpenExp x1 x2 + , Just Refl <- matchOpenExp y1 y2 + , Just Refl <- matchOpenExp i1 i2 + = Just Refl + +matchOpenExp (Select eR1 p1 x1 y1) (Select eR2 p2 x2 y2) + | Just Refl <- matchScalarType eR1 eR2 + , Just Refl <- matchOpenExp p1 p2 + , Just Refl <- matchOpenExp x1 x2 + , Just Refl <- matchOpenExp y1 y2 = Just Refl matchOpenExp (IndexSlice sliceIndex1 ix1 sh1) (IndexSlice sliceIndex2 ix2 sh2) @@ -497,6 +526,30 @@ matchOpenExp (FromIndex _ sh1 i1) (FromIndex _ sh2 i2) , Just Refl <- matchOpenExp sh1 sh2 = Just Refl +matchOpenExp (Case eR1 e1 rhs1 def1) (Case eR2 e2 rhs2 def2) + | Just Refl <- matchTagType eR1 eR2 + , Just Refl <- matchOpenExp e1 e2 + , Just Refl <- matchCaseEqs eR1 rhs1 rhs2 + , Just Refl <- matchCaseDef def1 def2 + = Just Refl + where + matchCaseEqs :: TagType tag -> [(tag, OpenExp env aenv a)] -> [(tag, OpenExp env aenv b)] -> Maybe (a :~: b) + matchCaseEqs _ [] [] + = Just (unsafeCoerce Refl) + matchCaseEqs tR ((s,x):xs) ((t,y):ys) + | TagDict <- tagDict tR + , s == t + , Just Refl <- matchOpenExp x y + , Just Refl <- matchCaseEqs tR xs ys + = Just Refl + matchCaseEqs _ _ _ + = Nothing + + matchCaseDef :: Maybe (OpenExp env aenv a) -> Maybe (OpenExp env aenv b) -> Maybe (a :~: b) + matchCaseDef Nothing Nothing = Just (unsafeCoerce Refl) + matchCaseDef (Just x) (Just y) = matchOpenExp x y + matchCaseDef _ _ = Nothing + matchOpenExp (Cond p1 t1 e1) (Cond p2 t2 e2) | Just Refl <- matchOpenExp p1 p2 , Just Refl <- matchOpenExp t1 t2 @@ -509,8 +562,10 @@ matchOpenExp (While p1 f1 x1) (While p2 f2 x2) , Just Refl <- matchOpenFun f1 f2 = Just Refl -matchOpenExp (PrimConst c1) (PrimConst c2) - = matchPrimConst c1 c2 +matchOpenExp (Const t1 c1) (Const t2 c2) + | Just Refl <- matchScalarType t1 t2 + , matchConst (TupRsingle t1) c1 c2 + = Just Refl matchOpenExp (PrimApp f1 x1) (PrimApp f2 x2) | Just x1' <- commutes f1 x1 @@ -541,6 +596,13 @@ matchOpenExp (ShapeSize _ sh1) (ShapeSize _ sh2) | Just Refl <- matchOpenExp sh1 sh2 = Just Refl +matchOpenExp (Undef t1) (Undef t2) = matchScalarType t1 t2 + +matchOpenExp (Bitcast _ t1 e1) (Bitcast _ t2 e2) + | Just Refl <- matchScalarType t1 t2 + , Just Refl <- matchOpenExp e1 e2 + = Just Refl + matchOpenExp _ _ = Nothing @@ -568,18 +630,41 @@ matchConst (TupRsingle ty) a b = evalEq ty (a,b) matchConst (TupRpair ta tb) (a1,b1) (a2,b2) = matchConst ta a1 a2 && matchConst tb b1 b2 evalEq :: ScalarType a -> (a, a) -> Bool -evalEq (SingleScalarType t) = evalEqSingle t -evalEq (VectorScalarType t) = evalEqVector t - -evalEqSingle :: SingleType a -> (a, a) -> Bool -evalEqSingle (NumSingleType t) = evalEqNum t - -evalEqVector :: VectorType a -> (a, a) -> Bool -evalEqVector VectorType{} = uncurry (==) - -evalEqNum :: NumType a -> (a, a) -> Bool -evalEqNum (IntegralNumType t) | IntegralDict <- integralDict t = uncurry (==) -evalEqNum (FloatingNumType t) | FloatingDict <- floatingDict t = uncurry (==) +evalEq t x = unBit $ scalar t (eq t x) + where + scalar :: ScalarType s -> BitOrMask s -> PrimBool + scalar (NumScalarType s) = num s + scalar (BitScalarType s) = bit s + + bit :: BitType s -> BitOrMask s -> PrimBool + bit TypeBit = id + bit TypeMask{} = vland bitType + + num :: NumType s -> BitOrMask s -> PrimBool + num (IntegralNumType s) = integral s + num (FloatingNumType s) = floating s + + integral :: IntegralType s -> BitOrMask s -> PrimBool + integral (VectorIntegralType n _) = vland (TypeMask n) + integral (SingleIntegralType s) = case s of + TypeInt8 -> id + TypeInt16 -> id + TypeInt32 -> id + TypeInt64 -> id + TypeInt128 -> id + TypeWord8 -> id + TypeWord16 -> id + TypeWord32 -> id + TypeWord64 -> id + TypeWord128 -> id + + floating :: FloatingType s -> BitOrMask s -> PrimBool + floating (VectorFloatingType n _) = vland (TypeMask n) + floating (SingleFloatingType s) = case s of + TypeFloat16 -> id + TypeFloat32 -> id + TypeFloat64 -> id + TypeFloat128 -> id -- Environment projection indices @@ -622,14 +707,8 @@ matchSliceIndex (SliceFixed sl1) (SliceFixed sl2) matchSliceIndex _ _ = Nothing --- Primitive constants and functions --- -matchPrimConst :: PrimConst s -> PrimConst t -> Maybe (s :~: t) -matchPrimConst (PrimMinBound s) (PrimMinBound t) = matchBoundedType s t -matchPrimConst (PrimMaxBound s) (PrimMaxBound t) = matchBoundedType s t -matchPrimConst (PrimPi s) (PrimPi t) = matchFloatingType s t -matchPrimConst _ _ = Nothing - +-- Primitive functions +-- ------------------- -- Covariant function matching -- @@ -692,11 +771,13 @@ matchPrimFun (PrimEq _) (PrimEq _) = Just Refl matchPrimFun (PrimNEq _) (PrimNEq _) = Just Refl matchPrimFun (PrimMax _) (PrimMax _) = Just Refl matchPrimFun (PrimMin _) (PrimMin _) = Just Refl +matchPrimFun (PrimLAnd s) (PrimLAnd t) = matchBitType s t +matchPrimFun (PrimLOr s) (PrimLOr t) = matchBitType s t +matchPrimFun (PrimLNot s) (PrimLNot t) = matchBitType s t matchPrimFun (PrimFromIntegral _ s) (PrimFromIntegral _ t) = matchNumType s t matchPrimFun (PrimToFloating _ s) (PrimToFloating _ t) = matchFloatingType s t -matchPrimFun PrimLAnd PrimLAnd = Just Refl -matchPrimFun PrimLOr PrimLOr = Just Refl -matchPrimFun PrimLNot PrimLNot = Just Refl +matchPrimFun (PrimToBool _ s) (PrimToBool _ t) = matchBitType s t +matchPrimFun (PrimFromBool _ s) (PrimFromBool _ t) = matchIntegralType s t matchPrimFun _ _ = Nothing @@ -757,34 +838,36 @@ matchPrimFun' (PrimIsNaN s) (PrimIsNaN t) = matchFloat matchPrimFun' (PrimIsInfinite s) (PrimIsInfinite t) = matchFloatingType s t matchPrimFun' (PrimMax _) (PrimMax _) = Just Refl matchPrimFun' (PrimMin _) (PrimMin _) = Just Refl +matchPrimFun' (PrimLAnd _) (PrimLAnd _) = Just Refl +matchPrimFun' (PrimLOr _) (PrimLOr _) = Just Refl +matchPrimFun' (PrimLNot _) (PrimLNot _) = Just Refl matchPrimFun' (PrimFromIntegral s _) (PrimFromIntegral t _) = matchIntegralType s t matchPrimFun' (PrimToFloating s _) (PrimToFloating t _) = matchNumType s t -matchPrimFun' PrimLAnd PrimLAnd = Just Refl -matchPrimFun' PrimLOr PrimLOr = Just Refl -matchPrimFun' PrimLNot PrimLNot = Just Refl +matchPrimFun' (PrimToBool s _) (PrimToBool t _) = matchIntegralType s t +matchPrimFun' (PrimFromBool s _) (PrimFromBool t _) = matchBitType s t matchPrimFun' (PrimLt s) (PrimLt t) - | Just Refl <- matchSingleType s t + | Just Refl <- matchScalarType s t = Just Refl matchPrimFun' (PrimGt s) (PrimGt t) - | Just Refl <- matchSingleType s t + | Just Refl <- matchScalarType s t = Just Refl matchPrimFun' (PrimLtEq s) (PrimLtEq t) - | Just Refl <- matchSingleType s t + | Just Refl <- matchScalarType s t = Just Refl matchPrimFun' (PrimGtEq s) (PrimGtEq t) - | Just Refl <- matchSingleType s t + | Just Refl <- matchScalarType s t = Just Refl matchPrimFun' (PrimEq s) (PrimEq t) - | Just Refl <- matchSingleType s t + | Just Refl <- matchScalarType s t = Just Refl matchPrimFun' (PrimNEq s) (PrimNEq t) - | Just Refl <- matchSingleType s t + | Just Refl <- matchScalarType s t = Just Refl matchPrimFun' _ _ @@ -831,24 +914,17 @@ matchShapeR _ _ = Nothing -- {-# INLINEABLE matchScalarType #-} matchScalarType :: ScalarType s -> ScalarType t -> Maybe (s :~: t) -matchScalarType (SingleScalarType s) (SingleScalarType t) = matchSingleType s t -matchScalarType (VectorScalarType s) (VectorScalarType t) = matchVectorType s t -matchScalarType _ _ = Nothing - -{-# INLINEABLE matchSingleType #-} -matchSingleType :: SingleType s -> SingleType t -> Maybe (s :~: t) -matchSingleType (NumSingleType s) (NumSingleType t) = matchNumType s t - -{-# INLINEABLE matchVectorType #-} -matchVectorType :: forall m n s t. VectorType (Vec n s) -> VectorType (Vec m t) -> Maybe (Vec n s :~: Vec m t) -matchVectorType (VectorType n s) (VectorType m t) - | Just Refl <- if n == m - then Just (unsafeCoerce Refl :: n :~: m) -- XXX: we don't have an embedded KnownNat constraint, but - else Nothing -- this implementation is the same as 'GHC.TypeLits.sameNat' - , Just Refl <- matchSingleType s t - = Just Refl -matchVectorType _ _ - = Nothing +matchScalarType (NumScalarType s) (NumScalarType t) = matchNumType s t +matchScalarType (BitScalarType s) (BitScalarType t) = matchBitType s t +matchScalarType _ _ = Nothing + +{-# INLINEABLE matchBitType #-} +matchBitType :: BitType s -> BitType t -> Maybe (s :~: t) +matchBitType TypeBit TypeBit = Just Refl +matchBitType (TypeMask n) (TypeMask m) + | Just Refl <- sameNat' n m + = Just Refl +matchBitType _ _ = Nothing {-# INLINEABLE matchNumType #-} matchNumType :: NumType s -> NumType t -> Maybe (s :~: t) @@ -856,30 +932,44 @@ matchNumType (IntegralNumType s) (IntegralNumType t) = matchIntegralType s t matchNumType (FloatingNumType s) (FloatingNumType t) = matchFloatingType s t matchNumType _ _ = Nothing -{-# INLINEABLE matchBoundedType #-} -matchBoundedType :: BoundedType s -> BoundedType t -> Maybe (s :~: t) -matchBoundedType (IntegralBoundedType s) (IntegralBoundedType t) = matchIntegralType s t - {-# INLINEABLE matchIntegralType #-} matchIntegralType :: IntegralType s -> IntegralType t -> Maybe (s :~: t) -matchIntegralType TypeInt TypeInt = Just Refl -matchIntegralType TypeInt8 TypeInt8 = Just Refl -matchIntegralType TypeInt16 TypeInt16 = Just Refl -matchIntegralType TypeInt32 TypeInt32 = Just Refl -matchIntegralType TypeInt64 TypeInt64 = Just Refl -matchIntegralType TypeWord TypeWord = Just Refl -matchIntegralType TypeWord8 TypeWord8 = Just Refl -matchIntegralType TypeWord16 TypeWord16 = Just Refl -matchIntegralType TypeWord32 TypeWord32 = Just Refl -matchIntegralType TypeWord64 TypeWord64 = Just Refl -matchIntegralType _ _ = Nothing +matchIntegralType (SingleIntegralType s) (SingleIntegralType t) = matchSingleIntegralType s t +matchIntegralType (VectorIntegralType n s) (VectorIntegralType m t) + | Just Refl <- sameNat' n m + , Just Refl <- matchSingleIntegralType s t + = Just Refl +matchIntegralType _ _ = Nothing + +{-# INLINEABLE matchSingleIntegralType #-} +matchSingleIntegralType :: SingleIntegralType s -> SingleIntegralType t -> Maybe (s :~: t) +matchSingleIntegralType (TypeInt n) (TypeInt m) | m == n = Just (unsafeCoerce Refl) +matchSingleIntegralType (TypeWord n) (TypeWord m) | m == n = Just (unsafeCoerce Refl) +matchSingleIntegralType _ _ = Nothing {-# INLINEABLE matchFloatingType #-} matchFloatingType :: FloatingType s -> FloatingType t -> Maybe (s :~: t) -matchFloatingType TypeHalf TypeHalf = Just Refl -matchFloatingType TypeFloat TypeFloat = Just Refl -matchFloatingType TypeDouble TypeDouble = Just Refl -matchFloatingType _ _ = Nothing +matchFloatingType (SingleFloatingType s) (SingleFloatingType t) = matchSingleFloatingType s t +matchFloatingType (VectorFloatingType n s) (VectorFloatingType m t) + | Just Refl <- sameNat' n m + , Just Refl <- matchSingleFloatingType s t + = Just Refl +matchFloatingType _ _ = Nothing + +{-# INLINEABLE matchSingleFloatingType #-} +matchSingleFloatingType :: SingleFloatingType s -> SingleFloatingType t -> Maybe (s :~: t) +matchSingleFloatingType TypeFloat16 TypeFloat16 = Just Refl +matchSingleFloatingType TypeFloat32 TypeFloat32 = Just Refl +matchSingleFloatingType TypeFloat64 TypeFloat64 = Just Refl +matchSingleFloatingType TypeFloat128 TypeFloat128 = Just Refl +matchSingleFloatingType _ _ = Nothing + +{-# INLINEABLE matchTagType #-} +matchTagType :: TagType s -> TagType t -> Maybe (s :~: t) +matchTagType TagBit TagBit = Just Refl +matchTagType TagWord8 TagWord8 = Just Refl +matchTagType TagWord16 TagWord16 = Just Refl +matchTagType _ _ = Nothing -- Auxiliary @@ -904,14 +994,22 @@ commutes f x = case f of PrimNEq{} -> Just (swizzle x) PrimMax{} -> Just (swizzle x) PrimMin{} -> Just (swizzle x) - PrimLAnd -> Just (swizzle x) - PrimLOr -> Just (swizzle x) + PrimLAnd{} -> Just (swizzle x) + PrimLOr{} -> Just (swizzle x) _ -> Nothing where swizzle :: OpenExp env aenv (a',a') -> OpenExp env aenv (a',a') - swizzle exp - | (a `Pair` b) <- exp + swizzle e + | (a `Pair` b) <- e , hashOpenExp a > hashOpenExp b = b `Pair` a -- - | otherwise = exp + | otherwise = e + +data TagDict t where + TagDict :: Eq t => TagDict t + +tagDict :: TagType t -> TagDict t +tagDict TagBit = TagDict +tagDict TagWord8 = TagDict +tagDict TagWord16 = TagDict diff --git a/src/Data/Array/Accelerate/Array/Data.hs b/src/Data/Array/Accelerate/Array/Data.hs index 3885d0b4e..3871b75c1 100644 --- a/src/Data/Array/Accelerate/Array/Data.hs +++ b/src/Data/Array/Accelerate/Array/Data.hs @@ -1,6 +1,8 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -36,16 +38,9 @@ module Data.Array.Accelerate.Array.Data ( touchArrayData, rnfArrayData, - -- * Type macros - HTYPE_INT, HTYPE_WORD, HTYPE_CLONG, HTYPE_CULONG, HTYPE_CCHAR, - -- * Allocator internals registerForeignPtrAllocator, - -- * Utilities for type classes - ScalarArrayDict(..), scalarArrayDict, - SingleArrayDict(..), singleArrayDict, - -- * TemplateHaskell liftArrayData, @@ -55,6 +50,7 @@ import Data.Array.Accelerate.Array.Unique import Data.Array.Accelerate.Error import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Type +import Data.Primitive.Bit import Data.Primitive.Vec #ifdef ACCELERATE_DEBUG import Data.Array.Accelerate.Lifetime @@ -67,11 +63,9 @@ import Data.Array.Accelerate.Debug.Internal.Trace import Control.Applicative import Control.DeepSeq import Control.Monad ( (<=<) ) -import Data.Bits +import Data.Bits ( testBit, setBit, clearBit ) import Data.IORef -import Data.Primitive ( sizeOf# ) import Foreign.ForeignPtr -import Foreign.Storable import Formatting hiding ( bytes ) import Language.Haskell.TH.Extra hiding ( Type ) import System.IO.Unsafe @@ -79,6 +73,7 @@ import Prelude hiding ( map import GHC.Exts hiding ( build ) import GHC.ForeignPtr +import GHC.TypeLits import GHC.Types @@ -101,168 +96,398 @@ type family GArrayDataR ba a where type ScalarArrayData a = UniqueArray (ScalarArrayDataR a) --- | Mapping from scalar type to the type as represented in memory in an --- array. +-- | Mapping from scalar type to the type as represented in memory in an array -- type family ScalarArrayDataR t where - ScalarArrayDataR Int = Int - ScalarArrayDataR Int8 = Int8 - ScalarArrayDataR Int16 = Int16 - ScalarArrayDataR Int32 = Int32 - ScalarArrayDataR Int64 = Int64 - ScalarArrayDataR Word = Word - ScalarArrayDataR Word8 = Word8 - ScalarArrayDataR Word16 = Word16 - ScalarArrayDataR Word32 = Word32 - ScalarArrayDataR Word64 = Word64 - ScalarArrayDataR Half = Half - ScalarArrayDataR Float = Float - ScalarArrayDataR Double = Double - ScalarArrayDataR (Vec n t) = ScalarArrayDataR t - - -data ScalarArrayDict a where - ScalarArrayDict :: ( ArrayData a ~ ScalarArrayData a, ScalarArrayDataR a ~ ScalarArrayDataR b ) - => {-# UNPACK #-} !Int -- vector width - -> SingleType b -- base type - -> ScalarArrayDict a - -data SingleArrayDict a where - SingleArrayDict :: ( ArrayData a ~ ScalarArrayData a, ScalarArrayDataR a ~ a ) - => SingleArrayDict a - -scalarArrayDict :: ScalarType a -> ScalarArrayDict a -scalarArrayDict = scalar - where - scalar :: ScalarType a -> ScalarArrayDict a - scalar (VectorScalarType t) = vector t - scalar (SingleScalarType t) - | SingleArrayDict <- singleArrayDict t - = ScalarArrayDict 1 t - - vector :: VectorType a -> ScalarArrayDict a - vector (VectorType w s) - | SingleArrayDict <- singleArrayDict s - = ScalarArrayDict w s - -singleArrayDict :: SingleType a -> SingleArrayDict a -singleArrayDict = single - where - single :: SingleType a -> SingleArrayDict a - single (NumSingleType t) = num t - - num :: NumType a -> SingleArrayDict a - num (IntegralNumType t) = integral t - num (FloatingNumType t) = floating t - - integral :: IntegralType a -> SingleArrayDict a - integral TypeInt = SingleArrayDict - integral TypeInt8 = SingleArrayDict - integral TypeInt16 = SingleArrayDict - integral TypeInt32 = SingleArrayDict - integral TypeInt64 = SingleArrayDict - integral TypeWord = SingleArrayDict - integral TypeWord8 = SingleArrayDict - integral TypeWord16 = SingleArrayDict - integral TypeWord32 = SingleArrayDict - integral TypeWord64 = SingleArrayDict - - floating :: FloatingType a -> SingleArrayDict a - floating TypeHalf = SingleArrayDict - floating TypeFloat = SingleArrayDict - floating TypeDouble = SingleArrayDict + ScalarArrayDataR Bit = Word8 + ScalarArrayDataR Int8 = Int8 + ScalarArrayDataR Int16 = Int16 + ScalarArrayDataR Int32 = Int32 + ScalarArrayDataR Int64 = Int64 + ScalarArrayDataR Int128 = Int128 + ScalarArrayDataR Word8 = Word8 + ScalarArrayDataR Word16 = Word16 + ScalarArrayDataR Word32 = Word32 + ScalarArrayDataR Word64 = Word64 + ScalarArrayDataR Word128 = Word128 + ScalarArrayDataR Half = Half + ScalarArrayDataR Float = Float + ScalarArrayDataR Double = Double + ScalarArrayDataR Float128 = Float128 + -- + ScalarArrayDataR (Vec n Bit) = BitMask n + ScalarArrayDataR (Vec n Int8) = Vec n Int8 + ScalarArrayDataR (Vec n Int16) = Vec n Int16 + ScalarArrayDataR (Vec n Int32) = Vec n Int32 + ScalarArrayDataR (Vec n Int64) = Vec n Int64 + ScalarArrayDataR (Vec n Int128) = Vec n Int128 + ScalarArrayDataR (Vec n Word8) = Vec n Word8 + ScalarArrayDataR (Vec n Word16) = Vec n Word16 + ScalarArrayDataR (Vec n Word32) = Vec n Word32 + ScalarArrayDataR (Vec n Word64) = Vec n Word64 + ScalarArrayDataR (Vec n Word128) = Vec n Word128 + ScalarArrayDataR (Vec n Half) = Vec n Half + ScalarArrayDataR (Vec n Float) = Vec n Float + ScalarArrayDataR (Vec n Double) = Vec n Double + ScalarArrayDataR (Vec n Float128) = Vec n Float128 -- Array operations -- ---------------- -newArrayData :: HasCallStack => TupR ScalarType e -> Int -> IO (MutableArrayData e) +newArrayData :: HasCallStack => TypeR e -> Int -> IO (MutableArrayData e) newArrayData TupRunit !_ = return () newArrayData (TupRpair t1 t2) !size = (,) <$> newArrayData t1 size <*> newArrayData t2 size -newArrayData (TupRsingle t) !size - | SingleScalarType s <- t - , SingleDict <- singleDict s - , SingleArrayDict <- singleArrayDict s - = allocateArray size - -- - | VectorScalarType v <- t - , VectorType w s <- v - , SingleDict <- singleDict s - , SingleArrayDict <- singleArrayDict s - = allocateArray (w * size) +newArrayData (TupRsingle _t) !size = scalar _t + where + scalar :: ScalarType t -> IO (MutableArrayData t) + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + bit :: BitType t -> IO (MutableArrayData t) + bit TypeBit = allocateArray ((size + 7) `quot` 8) + bit (TypeMask n) = allocateArray (((size * fromInteger (natVal' n)) + 7) `quot` 8) + + num :: NumType t -> IO (MutableArrayData t) + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t -indexArrayData :: TupR ScalarType e -> ArrayData e -> Int -> e + integral :: IntegralType t -> IO (MutableArrayData t) + integral = \case + SingleIntegralType t -> single t + VectorIntegralType n t -> vector n (fromInteger $ natVal' n) t + where + single :: SingleIntegralType t -> IO (MutableArrayData t) + single = \case + TypeInt8 -> allocateArray size + TypeInt16 -> allocateArray (2 * size) + TypeInt32 -> allocateArray (4 * size) + TypeInt64 -> allocateArray (8 * size) + TypeInt128 -> allocateArray (16 * size) + TypeWord8 -> allocateArray size + TypeWord16 -> allocateArray (2 * size) + TypeWord32 -> allocateArray (4 * size) + TypeWord64 -> allocateArray (8 * size) + TypeWord128 -> allocateArray (16 * size) + + vector :: Proxy# n -> Int -> SingleIntegralType t -> IO (MutableArrayData (Vec n t)) + vector _ !k = \case + TypeInt8 -> allocateArray (k * size) + TypeInt16 -> allocateArray (2 * k * size) + TypeInt32 -> allocateArray (4 * k * size) + TypeInt64 -> allocateArray (8 * k * size) + TypeInt128 -> allocateArray (16 * k * size) + TypeWord8 -> allocateArray (k * size) + TypeWord16 -> allocateArray (2 * k * size) + TypeWord32 -> allocateArray (4 * k * size) + TypeWord64 -> allocateArray (8 * k * size) + TypeWord128 -> allocateArray (16 * k * size) + + floating :: FloatingType t -> IO (MutableArrayData t) + floating = \case + SingleFloatingType t -> single t + VectorFloatingType n t -> vector n (fromInteger $ natVal' n) t + where + single :: SingleFloatingType t -> IO (MutableArrayData t) + single = \case + TypeFloat16 -> allocateArray (2 * size) + TypeFloat32 -> allocateArray (4 * size) + TypeFloat64 -> allocateArray (8 * size) + TypeFloat128 -> allocateArray (16 * size) + + vector :: Proxy# n -> Int -> SingleFloatingType t -> IO (MutableArrayData (Vec n t)) + vector _ !k = \case + TypeFloat16 -> allocateArray (2 * k * size) + TypeFloat32 -> allocateArray (4 * k * size) + TypeFloat64 -> allocateArray (8 * k * size) + TypeFloat128 -> allocateArray (16 * k * size) + + +indexArrayData :: TypeR e -> ArrayData e -> Int -> e indexArrayData tR arr ix = unsafePerformIO $ readArrayData tR arr ix -readArrayData :: forall e. TupR ScalarType e -> MutableArrayData e -> Int -> IO e +readArrayData :: TypeR e -> MutableArrayData e -> Int -> IO e readArrayData TupRunit () !_ = return () readArrayData (TupRpair t1 t2) (a1, a2) !ix = (,) <$> readArrayData t1 a1 ix <*> readArrayData t2 a2 ix -readArrayData (TupRsingle t) arr !ix - | SingleScalarType s <- t - , SingleDict <- singleDict s - , SingleArrayDict <- singleArrayDict s - = unsafeReadArray arr ix - -- - | VectorScalarType v <- t - , VectorType w s <- v - , I# w# <- w - , I# ix# <- ix - , SingleDict <- singleDict s - , SingleArrayDict <- singleArrayDict s - = let - !bytes# = w# *# sizeOf# (undefined :: ScalarArrayDataR e) - !addr# = unPtr# (unsafeUniqueArrayPtr arr) `plusAddr#` (ix# *# bytes#) - in - IO $ \s0 -> - case newAlignedPinnedByteArray# bytes# 16# s0 of { (# s1, mba# #) -> - case copyAddrToByteArray# addr# mba# 0# bytes# s1 of { s2 -> - case unsafeFreezeByteArray# mba# s2 of { (# s3, ba# #) -> - (# s3, Vec ba# #) - }}} - -writeArrayData :: forall e. TupR ScalarType e -> MutableArrayData e -> Int -> e -> IO () -writeArrayData TupRunit () !_ () = return () +readArrayData (TupRsingle _t) arr !ix = scalar _t arr ix + where + scalar :: ScalarType t -> MutableArrayData t -> Int -> IO t + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + bit :: BitType t -> MutableArrayData t -> Int -> IO t + bit TypeMask{} ua i = unMask <$> unsafeReadArray ua i + bit TypeBit ua i = + let (q,r) = quotRem i 8 + in do + w <- unsafeReadArray ua q + return $ Bit (testBit w r) + + num :: NumType t -> MutableArrayData t -> Int -> IO t + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType t -> MutableArrayData t -> Int -> IO t + integral = \case + SingleIntegralType t -> single t + VectorIntegralType n t -> vector n t + where + single :: SingleIntegralType t -> MutableArrayData t -> Int -> IO t + single = \case + TypeInt8 -> unsafeReadArray + TypeInt16 -> unsafeReadArray + TypeInt32 -> unsafeReadArray + TypeInt64 -> unsafeReadArray + TypeInt128 -> unsafeReadArray + TypeWord8 -> unsafeReadArray + TypeWord16 -> unsafeReadArray + TypeWord32 -> unsafeReadArray + TypeWord64 -> unsafeReadArray + TypeWord128 -> unsafeReadArray + + vector :: KnownNat n => Proxy# n -> SingleIntegralType t -> MutableArrayData (Vec n t) -> Int -> IO (Vec n t) + vector _ = \case + TypeInt8 -> unsafeReadArray + TypeInt16 -> unsafeReadArray + TypeInt32 -> unsafeReadArray + TypeInt64 -> unsafeReadArray + TypeInt128 -> unsafeReadArray + TypeWord8 -> unsafeReadArray + TypeWord16 -> unsafeReadArray + TypeWord32 -> unsafeReadArray + TypeWord64 -> unsafeReadArray + TypeWord128 -> unsafeReadArray + + floating :: FloatingType t -> MutableArrayData t -> Int -> IO t + floating = \case + SingleFloatingType t -> single t + VectorFloatingType n t -> vector n t + where + single :: SingleFloatingType t -> MutableArrayData t -> Int -> IO t + single = \case + TypeFloat16 -> unsafeReadArray + TypeFloat32 -> unsafeReadArray + TypeFloat64 -> unsafeReadArray + TypeFloat128 -> unsafeReadArray + + vector :: KnownNat n => Proxy# n -> SingleFloatingType t -> MutableArrayData (Vec n t) -> Int -> IO (Vec n t) + vector _ = \case + TypeFloat16 -> unsafeReadArray + TypeFloat32 -> unsafeReadArray + TypeFloat64 -> unsafeReadArray + TypeFloat128 -> unsafeReadArray + + +writeArrayData :: TypeR e -> MutableArrayData e -> Int -> e -> IO () +writeArrayData TupRunit () !_ !() = return () writeArrayData (TupRpair t1 t2) (a1, a2) !ix (v1, v2) = writeArrayData t1 a1 ix v1 >> writeArrayData t2 a2 ix v2 -writeArrayData (TupRsingle t) arr !ix !val - | SingleScalarType s <- t - , SingleDict <- singleDict s - , SingleArrayDict <- singleArrayDict s - = unsafeWriteArray arr ix val - -- - | VectorScalarType v <- t - , VectorType w s <- v - , Vec ba# <- val - , I# w# <- w - , I# ix# <- ix - , SingleDict <- singleDict s - , SingleArrayDict <- singleArrayDict s - = let - !bytes# = w# *# sizeOf# (undefined :: ScalarArrayDataR e) - !addr# = unPtr# (unsafeUniqueArrayPtr arr) `plusAddr#` (ix# *# bytes#) - in - IO $ \s0 -> case copyByteArrayToAddr# ba# 0# addr# bytes# s0 of - s1 -> (# s1, () #) +writeArrayData (TupRsingle _t) arr !ix !val = scalar _t arr ix val + where + scalar :: ScalarType t -> MutableArrayData t -> Int -> t -> IO () + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + bit :: BitType t -> MutableArrayData t -> Int -> t -> IO () + bit TypeMask{} ua i m = unsafeWriteArray ua i (BitMask m) + bit TypeBit ua i (Bit b) = + let (q,r) = quotRem i 8 + update x = if b then setBit x r else clearBit x r + in do + w <- unsafeReadArray ua q + unsafeWriteArray ua q (update w) + + num :: NumType t -> MutableArrayData t -> Int -> t -> IO () + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType t -> MutableArrayData t -> Int -> t -> IO () + integral = \case + SingleIntegralType t -> single t + VectorIntegralType n t -> vector n t + where + single :: SingleIntegralType t -> MutableArrayData t -> Int -> t -> IO () + single = \case + TypeInt8 -> unsafeWriteArray + TypeInt16 -> unsafeWriteArray + TypeInt32 -> unsafeWriteArray + TypeInt64 -> unsafeWriteArray + TypeInt128 -> unsafeWriteArray + TypeWord8 -> unsafeWriteArray + TypeWord16 -> unsafeWriteArray + TypeWord32 -> unsafeWriteArray + TypeWord64 -> unsafeWriteArray + TypeWord128 -> unsafeWriteArray + + vector :: KnownNat n => Proxy# n -> SingleIntegralType t -> MutableArrayData (Vec n t) -> Int -> Vec n t -> IO () + vector _ = \case + TypeInt8 -> unsafeWriteArray + TypeInt16 -> unsafeWriteArray + TypeInt32 -> unsafeWriteArray + TypeInt64 -> unsafeWriteArray + TypeInt128 -> unsafeWriteArray + TypeWord8 -> unsafeWriteArray + TypeWord16 -> unsafeWriteArray + TypeWord32 -> unsafeWriteArray + TypeWord64 -> unsafeWriteArray + TypeWord128 -> unsafeWriteArray + + floating :: FloatingType t -> MutableArrayData t -> Int -> t -> IO () + floating = \case + SingleFloatingType t -> single t + VectorFloatingType n t -> vector n t + where + single :: SingleFloatingType t -> MutableArrayData t -> Int -> t -> IO () + single = \case + TypeFloat16 -> unsafeWriteArray + TypeFloat32 -> unsafeWriteArray + TypeFloat64 -> unsafeWriteArray + TypeFloat128 -> unsafeWriteArray + + vector :: KnownNat n => Proxy# n -> SingleFloatingType t -> MutableArrayData (Vec n t) -> Int -> Vec n t -> IO () + vector _ = \case + TypeFloat16 -> unsafeWriteArray + TypeFloat32 -> unsafeWriteArray + TypeFloat64 -> unsafeWriteArray + TypeFloat128 -> unsafeWriteArray unsafeArrayDataPtr :: ScalarType e -> ArrayData e -> Ptr (ScalarArrayDataR e) -unsafeArrayDataPtr t arr - | ScalarArrayDict{} <- scalarArrayDict t - = unsafeUniqueArrayPtr arr +unsafeArrayDataPtr = scalar + where + scalar :: ScalarType t -> ArrayData t -> Ptr (ScalarArrayDataR t) + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + bit :: BitType t -> ArrayData t -> Ptr (ScalarArrayDataR t) + bit TypeBit = unsafeUniqueArrayPtr + bit TypeMask{} = unsafeUniqueArrayPtr + + num :: NumType t -> ArrayData t -> Ptr (ScalarArrayDataR t) + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType t -> ArrayData t -> Ptr (ScalarArrayDataR t) + integral = \case + SingleIntegralType t -> single t + VectorIntegralType n t -> vector n t + where + single :: SingleIntegralType t -> ArrayData t -> Ptr (ScalarArrayDataR t) + single = \case + TypeInt8 -> unsafeUniqueArrayPtr + TypeInt16 -> unsafeUniqueArrayPtr + TypeInt32 -> unsafeUniqueArrayPtr + TypeInt64 -> unsafeUniqueArrayPtr + TypeInt128 -> unsafeUniqueArrayPtr + TypeWord8 -> unsafeUniqueArrayPtr + TypeWord16 -> unsafeUniqueArrayPtr + TypeWord32 -> unsafeUniqueArrayPtr + TypeWord64 -> unsafeUniqueArrayPtr + TypeWord128 -> unsafeUniqueArrayPtr + + vector :: KnownNat n => Proxy# n -> SingleIntegralType t -> ArrayData (Vec n t) -> Ptr (ScalarArrayDataR (Vec n t)) + vector _ = \case + TypeInt8 -> unsafeUniqueArrayPtr + TypeInt16 -> unsafeUniqueArrayPtr + TypeInt32 -> unsafeUniqueArrayPtr + TypeInt64 -> unsafeUniqueArrayPtr + TypeInt128 -> unsafeUniqueArrayPtr + TypeWord8 -> unsafeUniqueArrayPtr + TypeWord16 -> unsafeUniqueArrayPtr + TypeWord32 -> unsafeUniqueArrayPtr + TypeWord64 -> unsafeUniqueArrayPtr + TypeWord128 -> unsafeUniqueArrayPtr + + floating :: FloatingType t -> ArrayData t -> Ptr (ScalarArrayDataR t) + floating = \case + SingleFloatingType t -> single t + VectorFloatingType n t -> vector n t + where + single :: SingleFloatingType t -> ArrayData t -> Ptr (ScalarArrayDataR t) + single = \case + TypeFloat16 -> unsafeUniqueArrayPtr + TypeFloat32 -> unsafeUniqueArrayPtr + TypeFloat64 -> unsafeUniqueArrayPtr + TypeFloat128 -> unsafeUniqueArrayPtr + + vector :: KnownNat n => Proxy# n -> SingleFloatingType t -> ArrayData (Vec n t) -> Ptr (ScalarArrayDataR (Vec n t)) + vector _ = \case + TypeFloat16 -> unsafeUniqueArrayPtr + TypeFloat32 -> unsafeUniqueArrayPtr + TypeFloat64 -> unsafeUniqueArrayPtr + TypeFloat128 -> unsafeUniqueArrayPtr touchArrayData :: TupR ScalarType e -> ArrayData e -> IO () touchArrayData TupRunit () = return () touchArrayData (TupRpair t1 t2) (a1, a2) = touchArrayData t1 a1 >> touchArrayData t2 a2 -touchArrayData (TupRsingle t) arr - | ScalarArrayDict{} <- scalarArrayDict t - = touchUniqueArray arr +touchArrayData (TupRsingle ta) arr = scalar ta arr + where + scalar :: ScalarType t -> ArrayData t -> IO () + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + bit :: BitType t -> ArrayData t -> IO () + bit TypeBit = touchUniqueArray + bit TypeMask{} = touchUniqueArray + + num :: NumType t -> ArrayData t -> IO () + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType t -> ArrayData t -> IO () + integral = \case + SingleIntegralType t -> single t + VectorIntegralType n t -> vector n t + where + single :: SingleIntegralType t -> ArrayData t -> IO () + single = \case + TypeInt8 -> touchUniqueArray + TypeInt16 -> touchUniqueArray + TypeInt32 -> touchUniqueArray + TypeInt64 -> touchUniqueArray + TypeInt128 -> touchUniqueArray + TypeWord8 -> touchUniqueArray + TypeWord16 -> touchUniqueArray + TypeWord32 -> touchUniqueArray + TypeWord64 -> touchUniqueArray + TypeWord128 -> touchUniqueArray + + vector :: KnownNat n => Proxy# n -> SingleIntegralType t -> ArrayData (Vec n t) -> IO () + vector _ = \case + TypeInt8 -> touchUniqueArray + TypeInt16 -> touchUniqueArray + TypeInt32 -> touchUniqueArray + TypeInt64 -> touchUniqueArray + TypeInt128 -> touchUniqueArray + TypeWord8 -> touchUniqueArray + TypeWord16 -> touchUniqueArray + TypeWord32 -> touchUniqueArray + TypeWord64 -> touchUniqueArray + TypeWord128 -> touchUniqueArray + + floating :: FloatingType t -> ArrayData t -> IO () + floating = \case + SingleFloatingType t -> single t + VectorFloatingType n t -> vector n t + where + single :: SingleFloatingType t -> ArrayData t -> IO () + single = \case + TypeFloat16 -> touchUniqueArray + TypeFloat32 -> touchUniqueArray + TypeFloat64 -> touchUniqueArray + TypeFloat128 -> touchUniqueArray + + vector :: KnownNat n => Proxy# n -> SingleFloatingType t -> ArrayData (Vec n t) -> IO () + vector _ = \case + TypeFloat16 -> touchUniqueArray + TypeFloat32 -> touchUniqueArray + TypeFloat64 -> touchUniqueArray + TypeFloat128 -> touchUniqueArray rnfArrayData :: TupR ScalarType e -> ArrayData e -> () rnfArrayData TupRunit () = () rnfArrayData (TupRpair t1 t2) (a1, a2) = rnfArrayData t1 a1 `seq` rnfArrayData t2 a2 `seq` () rnfArrayData (TupRsingle t) arr = rnf (unsafeArrayDataPtr t arr) -unPtr# :: Ptr a -> Addr# -unPtr# (Ptr addr#) = addr# -- | Safe combination of creating and fast freezing of array data. -- @@ -273,22 +498,20 @@ runArrayData st = unsafePerformIO $ do (mad, r) <- st return (mad, r) --- Allocate a new array with enough storage to hold the given number of --- elements. +-- Allocate a new array of the given number of bytes. -- -- The array is uninitialised and, in particular, allocated lazily. The latter -- is important because it means that for backends that have discrete memory -- spaces (e.g. GPUs), we will not increase host memory pressure simply to track -- intermediate arrays that contain meaningful data only on the device. -- -allocateArray :: forall e. (HasCallStack, Storable e) => Int -> IO (UniqueArray e) +allocateArray :: forall e. HasCallStack => Int -> IO (UniqueArray e) allocateArray !size = internalCheck "size must be >= 0" (size >= 0) $ do arr <- newUniqueArray <=< unsafeInterleaveIO $ do - let bytes = size * sizeOf (undefined :: e) new <- readIORef __mallocForeignPtrBytes - ptr <- new bytes - traceM dump_gc ("gc: allocated new host array (size=" % int % ", ptr=" % build % ")") bytes (unsafeForeignPtrToPtr ptr) - local_memory_alloc (unsafeForeignPtrToPtr ptr) bytes + ptr <- new size + traceM dump_gc ("gc: allocated new host array (size=" % int % ", ptr=" % build % ")") size (unsafeForeignPtrToPtr ptr) + local_memory_alloc (unsafeForeignPtrToPtr ptr) size return (castForeignPtr ptr) #ifdef ACCELERATE_DEBUG addFinalizer (uniqueArrayData arr) (local_memory_free (unsafeUniqueArrayPtr arr)) @@ -326,7 +549,7 @@ mallocPlainForeignPtrBytesAligned (I# size#) = IO $ \s0 -> liftArrayData :: Int -> TypeR e -> ArrayData e -> CodeQ (ArrayData e) -liftArrayData n = tuple +liftArrayData !size = tuple where tuple :: TypeR e -> ArrayData e -> CodeQ (ArrayData e) tuple TupRunit () = [|| () ||] @@ -334,66 +557,60 @@ liftArrayData n = tuple tuple (TupRsingle s) adata = scalar s adata scalar :: ScalarType e -> ArrayData e -> CodeQ (ArrayData e) - scalar (SingleScalarType t) = single t - scalar (VectorScalarType t) = vector t - - vector :: forall n e. VectorType (Vec n e) -> ArrayData (Vec n e) -> CodeQ (ArrayData (Vec n e)) - vector (VectorType w t) - | SingleArrayDict <- singleArrayDict t - = liftArrayData (w * n) (TupRsingle (SingleScalarType t)) + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t - single :: SingleType e -> ArrayData e -> CodeQ (ArrayData e) - single (NumSingleType t) = num t + bit :: BitType e -> ArrayData e -> CodeQ (ArrayData e) + bit TypeBit ua = liftUniqueArray ((size + 7) `quot` 8) ua + bit (TypeMask n) ua = liftUniqueArray (((size * fromInteger (natVal' n)) + 7) `quot` 8) ua num :: NumType e -> ArrayData e -> CodeQ (ArrayData e) num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t integral :: IntegralType e -> ArrayData e -> CodeQ (ArrayData e) - integral TypeInt = liftUniqueArray n - integral TypeInt8 = liftUniqueArray n - integral TypeInt16 = liftUniqueArray n - integral TypeInt32 = liftUniqueArray n - integral TypeInt64 = liftUniqueArray n - integral TypeWord = liftUniqueArray n - integral TypeWord8 = liftUniqueArray n - integral TypeWord16 = liftUniqueArray n - integral TypeWord32 = liftUniqueArray n - integral TypeWord64 = liftUniqueArray n + integral = \case + SingleIntegralType t -> single t size + VectorIntegralType n t -> vector n t (size * fromInteger (natVal' n)) + where + single :: SingleIntegralType e -> Int -> ArrayData e -> CodeQ (ArrayData e) + single TypeInt8 = liftUniqueArray + single TypeInt16 = liftUniqueArray + single TypeInt32 = liftUniqueArray + single TypeInt64 = liftUniqueArray + single TypeInt128 = liftUniqueArray + single TypeWord8 = liftUniqueArray + single TypeWord16 = liftUniqueArray + single TypeWord32 = liftUniqueArray + single TypeWord64 = liftUniqueArray + single TypeWord128 = liftUniqueArray + + vector :: KnownNat n => Proxy# n -> SingleIntegralType e -> Int -> ArrayData (Vec n e) -> CodeQ (ArrayData (Vec n e)) + vector _ TypeInt8 = liftUniqueArray + vector _ TypeInt16 = liftUniqueArray + vector _ TypeInt32 = liftUniqueArray + vector _ TypeInt64 = liftUniqueArray + vector _ TypeInt128 = liftUniqueArray + vector _ TypeWord8 = liftUniqueArray + vector _ TypeWord16 = liftUniqueArray + vector _ TypeWord32 = liftUniqueArray + vector _ TypeWord64 = liftUniqueArray + vector _ TypeWord128 = liftUniqueArray floating :: FloatingType e -> ArrayData e -> CodeQ (ArrayData e) - floating TypeHalf = liftUniqueArray n - floating TypeFloat = liftUniqueArray n - floating TypeDouble = liftUniqueArray n - --- Determine the underlying type of a Haskell CLong or CULong. --- -runQ [d| type HTYPE_INT = $( - case finiteBitSize (undefined::Int) of - 32 -> [t| Int32 |] - 64 -> [t| Int64 |] - _ -> error "I don't know what architecture I am" ) |] - -runQ [d| type HTYPE_WORD = $( - case finiteBitSize (undefined::Word) of - 32 -> [t| Word32 |] - 64 -> [t| Word64 |] - _ -> error "I don't know what architecture I am" ) |] - -runQ [d| type HTYPE_CLONG = $( - case finiteBitSize (undefined::CLong) of - 32 -> [t| Int32 |] - 64 -> [t| Int64 |] - _ -> error "I don't know what architecture I am" ) |] - -runQ [d| type HTYPE_CULONG = $( - case finiteBitSize (undefined::CULong) of - 32 -> [t| Word32 |] - 64 -> [t| Word64 |] - _ -> error "I don't know what architecture I am" ) |] - -runQ [d| type HTYPE_CCHAR = $( - if isSigned (undefined::CChar) - then [t| Int8 |] - else [t| Word8 |] ) |] + floating = \case + SingleFloatingType t -> single t size + VectorFloatingType n t -> vector n t (size * fromInteger (natVal' n)) + where + single :: SingleFloatingType e -> Int -> ArrayData e -> CodeQ (ArrayData e) + single TypeFloat16 = liftUniqueArray + single TypeFloat32 = liftUniqueArray + single TypeFloat64 = liftUniqueArray + single TypeFloat128 = liftUniqueArray + + vector :: KnownNat n => Proxy# n -> SingleFloatingType e -> Int -> ArrayData (Vec n e) -> CodeQ (ArrayData (Vec n e)) + vector _ TypeFloat16 = liftUniqueArray + vector _ TypeFloat32 = liftUniqueArray + vector _ TypeFloat64 = liftUniqueArray + vector _ TypeFloat128 = liftUniqueArray diff --git a/src/Data/Array/Accelerate/Array/Remote/Class.hs b/src/Data/Array/Accelerate/Array/Remote/Class.hs index 7a871bd1a..067161121 100644 --- a/src/Data/Array/Accelerate/Array/Remote/Class.hs +++ b/src/Data/Array/Accelerate/Array/Remote/Class.hs @@ -54,10 +54,10 @@ class (Applicative m, Monad m, MonadCatch m, MonadMask m) => RemoteMemory m wher mallocRemote :: Int -> m (Maybe (RemotePtr m Word8)) -- | Copy the given number of elements from the host array into remote memory. - pokeRemote :: SingleType e -> Int -> RemotePtr m (ScalarArrayDataR e) -> ArrayData e -> m () + pokeRemote :: ScalarType e -> Int -> RemotePtr m (ScalarArrayDataR e) -> ArrayData e -> m () -- | Copy the given number of elements from remote memory to the host array. - peekRemote :: SingleType e -> Int -> RemotePtr m (ScalarArrayDataR e) -> MutableArrayData e -> m () + peekRemote :: ScalarType e -> Int -> RemotePtr m (ScalarArrayDataR e) -> MutableArrayData e -> m () -- | Cast a remote pointer. castRemotePtr :: RemotePtr m a -> RemotePtr m b diff --git a/src/Data/Array/Accelerate/Array/Unique.hs b/src/Data/Array/Accelerate/Array/Unique.hs index 10aa97bd2..875203346 100644 --- a/src/Data/Array/Accelerate/Array/Unique.hs +++ b/src/Data/Array/Accelerate/Array/Unique.hs @@ -20,6 +20,7 @@ import Data.Array.Accelerate.Lifetime import Control.Applicative import Control.Concurrent.Unique import Control.DeepSeq +import Data.Coerce import Data.Word import Foreign.ForeignPtr import Foreign.ForeignPtr.Unsafe @@ -109,6 +110,15 @@ unsafeUniqueArrayPtr :: UniqueArray a -> Ptr a unsafeUniqueArrayPtr = unsafeForeignPtrToPtr . unsafeGetValue . uniqueArrayData +-- | Cast a unique array parameterised by one type into another type +-- +-- @since 1.4.0.0 +-- +{-# INLINE castUniqueArray #-} +castUniqueArray :: UniqueArray a -> UniqueArray b +castUniqueArray = coerce + + -- | Ensure that the unique array is alive at the given place in a sequence of -- IO actions. Note that this does not force the actual array payload. -- diff --git a/src/Data/Array/Accelerate/Classes/Bounded.hs b/src/Data/Array/Accelerate/Classes/Bounded.hs index c567fb14c..2229556ac 100644 --- a/src/Data/Array/Accelerate/Classes/Bounded.hs +++ b/src/Data/Array/Accelerate/Classes/Bounded.hs @@ -21,14 +21,15 @@ module Data.Array.Accelerate.Classes.Bounded ( ) where -import Data.Array.Accelerate.Array.Data -import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type +import qualified Data.Primitive.Vec as Prim -import Prelude ( ($), (<$>), Num(..), Char, Bool, show, concat, map, mapM ) import Language.Haskell.TH.Extra hiding ( Exp ) +import Prelude hiding ( Bounded ) import qualified Prelude as P @@ -42,113 +43,57 @@ instance P.Bounded (Exp ()) where minBound = constant () maxBound = constant () -instance P.Bounded (Exp Int) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Int8) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Int16) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Int32) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Int64) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Word) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Word8) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Word16) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Word32) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp Word64) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp CShort) where - minBound = mkBitcast (mkMinBound @Int16) - maxBound = mkBitcast (mkMaxBound @Int16) - -instance P.Bounded (Exp CUShort) where - minBound = mkBitcast (mkMinBound @Word16) - maxBound = mkBitcast (mkMaxBound @Word16) - -instance P.Bounded (Exp CInt) where - minBound = mkBitcast (mkMinBound @Int32) - maxBound = mkBitcast (mkMaxBound @Int32) - -instance P.Bounded (Exp CUInt) where - minBound = mkBitcast (mkMinBound @Word32) - maxBound = mkBitcast (mkMaxBound @Word32) - -instance P.Bounded (Exp CLong) where - minBound = mkBitcast (mkMinBound @HTYPE_CLONG) - maxBound = mkBitcast (mkMaxBound @HTYPE_CLONG) - -instance P.Bounded (Exp CULong) where - minBound = mkBitcast (mkMinBound @HTYPE_CULONG) - maxBound = mkBitcast (mkMaxBound @HTYPE_CULONG) - -instance P.Bounded (Exp CLLong) where - minBound = mkBitcast (mkMinBound @Int64) - maxBound = mkBitcast (mkMaxBound @Int64) - -instance P.Bounded (Exp CULLong) where - minBound = mkBitcast (mkMinBound @Word64) - maxBound = mkBitcast (mkMaxBound @Word64) - instance P.Bounded (Exp Bool) where minBound = constant P.minBound maxBound = constant P.maxBound instance P.Bounded (Exp Char) where - minBound = mkMinBound - maxBound = mkMaxBound - -instance P.Bounded (Exp CChar) where - minBound = mkBitcast (mkMinBound @HTYPE_CCHAR) - maxBound = mkBitcast (mkMaxBound @HTYPE_CCHAR) - -instance P.Bounded (Exp CSChar) where - minBound = mkBitcast (mkMinBound @Int8) - maxBound = mkBitcast (mkMaxBound @Int8) - -instance P.Bounded (Exp CUChar) where - minBound = mkBitcast (mkMinBound @Word8) - maxBound = mkBitcast (mkMaxBound @Word8) - -$(runQ $ do - let - mkInstance :: Int -> Q [Dec] - mkInstance n = - let - xs = [ mkName ('x':show i) | i <- [0 .. n-1] ] - cst = tupT (map (\x -> [t| Bounded $(varT x) |]) xs) - res = tupT (map varT xs) - app x = appsE (conE (mkName ('T':show n)) : P.replicate n x) - in - [d| instance $cst => P.Bounded (Exp $res) where - minBound = $(app [| P.minBound |]) - maxBound = $(app [| P.maxBound |]) - |] - -- - concat <$> mapM mkInstance [2..16] - ) + minBound = constant P.minBound + maxBound = constant P.maxBound + +runQ $ do + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + mkBounded :: Name -> Q [Dec] + mkBounded a = + [d| instance P.Bounded (Exp $(conT a)) where + minBound = constant P.minBound + maxBound = constant P.maxBound + + instance KnownNat n => P.Bounded (Exp (Vec n $(conT a))) where + minBound = constant (Vec (Prim.splat minBound)) + maxBound = constant (Vec (Prim.splat maxBound)) + |] + + mkTuple :: Int -> Q [Dec] + mkTuple n = + let + xs = [ mkName ('x':show i) | i <- [0 .. n-1] ] + cst = tupT (map (\x -> [t| Bounded $(varT x) |]) xs) + res = tupT (map varT xs) + app x = appsE (conE (mkName ('T':show n)) : P.replicate n x) + in + [d| instance $cst => P.Bounded (Exp $res) where + minBound = $(app [| P.minBound |]) + maxBound = $(app [| P.maxBound |]) + |] + -- + as <- mapM mkBounded integralTypes + ts <- mapM mkTuple [2..16] + return $ concat (as ++ ts) diff --git a/src/Data/Array/Accelerate/Classes/Enum.hs b/src/Data/Array/Accelerate/Classes/Enum.hs index 84b344273..cd2f3e2c6 100644 --- a/src/Data/Array/Accelerate/Classes/Enum.hs +++ b/src/Data/Array/Accelerate/Classes/Enum.hs @@ -1,7 +1,7 @@ {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE MonoLocalBinds #-} +{-# LANGUAGE TemplateHaskell #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.Classes.Enum @@ -22,10 +22,13 @@ module Data.Array.Accelerate.Classes.Enum ( import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type -import Text.Printf -import Prelude ( ($), String, error, unlines, succ, pred ) +import Control.Monad +import Language.Haskell.TH hiding ( Exp ) +import Text.Printf +import Prelude hiding ( Num, Enum ) import qualified Prelude as P @@ -33,145 +36,6 @@ import qualified Prelude as P -- type Enum a = P.Enum (Exp a) - -instance P.Enum (Exp Int) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Int8) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Int16) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Int32) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Int64) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Word) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Word8) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Word16) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Word32) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Word64) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CInt) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CUInt) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CLong) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CULong) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CLLong) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CULLong) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CShort) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CUShort) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Half) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Float) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp Double) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CFloat) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - -instance P.Enum (Exp CDouble) where - succ = defaultSucc - pred = defaultPred - toEnum = defaultToEnum - fromEnum = defaultFromEnum - defaultSucc :: Num a => Exp a -> Exp a defaultSucc x = x + 1 @@ -187,9 +51,54 @@ defaultFromEnum = preludeError "fromEnum" preludeError :: String -> a preludeError x = error - $ unlines [ printf "Prelude.%s is not supported for Accelerate types" x - , "" + $ unlines [ printf "Prelude.%s is not supported for Accelerate types" x , "" , "These Prelude.Enum instances are present only to fulfil superclass" , "constraints for subsequent classes in the standard Haskell numeric hierarchy." ] +runQ $ + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + numTypes :: [Name] + numTypes = integralTypes ++ floatingTypes + + mkEnum :: Name -> Q [Dec] + mkEnum a = + [d| instance P.Enum (Exp $(conT a)) where + succ = defaultSucc + pred = defaultPred + toEnum = defaultToEnum + fromEnum = defaultFromEnum + + instance KnownNat n => P.Enum (Exp (Vec n $(conT a))) where + succ = defaultSucc + pred = defaultPred + toEnum = defaultToEnum + fromEnum = defaultFromEnum + |] + in + concat <$> mapM mkEnum numTypes + diff --git a/src/Data/Array/Accelerate/Classes/Eq.hs b/src/Data/Array/Accelerate/Classes/Eq.hs index 6985facdd..dfb9bd492 100644 --- a/src/Data/Array/Accelerate/Classes/Eq.hs +++ b/src/Data/Array/Accelerate/Classes/Eq.hs @@ -1,7 +1,10 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} @@ -22,7 +25,7 @@ module Data.Array.Accelerate.Classes.Eq ( - Bool(..), pattern True_, pattern False_, + Bool, pattern True, pattern False, Eq(..), (&&), (&&!), (||), (||!), @@ -30,18 +33,20 @@ module Data.Array.Accelerate.Classes.Eq ( ) where -import Data.Array.Accelerate.AST.Idx -import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.AST ( PrimFun(..) ) +import Data.Array.Accelerate.Error import Data.Array.Accelerate.Pattern.Bool +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Shape +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type +import {-# SOURCE #-} Data.Array.Accelerate.Classes.VEq -import Data.Bool ( Bool(..) ) -import Data.Char ( Char ) import Text.Printf +import Data.Char import Prelude ( ($), String, Num(..), Ordering(..), show, error, return, concat, map, zipWith, foldr1, mapM ) import Language.Haskell.TH.Extra hiding ( Exp ) import qualified Prelude as P @@ -56,10 +61,7 @@ infix 4 /= infixr 3 && (&&) :: Exp Bool -> Exp Bool -> Exp Bool (&&) (Exp x) (Exp y) = - mkExp $ SmartExp (Cond (SmartExp $ Prj PairIdxLeft x) - (SmartExp $ Prj PairIdxLeft y) - (SmartExp $ Const scalarTypeWord8 0)) - `Pair` SmartExp Nil + mkExp $ Cond x y (SmartExp $ Const (scalarType @PrimBool) 0) -- | Conjunction: True if both arguments are true. This is a strict version of -- '(&&)': it will always evaluate both arguments, even when the first is false. @@ -68,7 +70,7 @@ infixr 3 && -- infixr 3 &&! (&&!) :: Exp Bool -> Exp Bool -> Exp Bool -(&&!) = mkLAnd +(&&!) = mkPrimBinary $ PrimLAnd bitType -- | Disjunction: True if either argument is true. This is a short-circuit -- operator, so the second argument will be evaluated only if the first is @@ -77,10 +79,7 @@ infixr 3 &&! infixr 2 || (||) :: Exp Bool -> Exp Bool -> Exp Bool (||) (Exp x) (Exp y) = - mkExp $ SmartExp (Cond (SmartExp $ Prj PairIdxLeft x) - (SmartExp $ Const scalarTypeWord8 1) - (SmartExp $ Prj PairIdxLeft y)) - `Pair` SmartExp Nil + mkExp $ Cond x (SmartExp $ Const (scalarType @PrimBool) 1) y -- | Disjunction: True if either argument is true. This is a strict version of @@ -90,34 +89,36 @@ infixr 2 || -- infixr 2 ||! (||!) :: Exp Bool -> Exp Bool -> Exp Bool -(||!) = mkLOr +(||!) = mkPrimBinary $ PrimLOr bitType -- | Logical negation -- not :: Exp Bool -> Exp Bool -not = mkLNot +not = mkPrimUnary $ PrimLNot bitType --- | The 'Eq' class defines equality '==' and inequality '/=' for scalar +-- | The 'Eq' class defines equality '(==)' and inequality '(/=)' for -- Accelerate expressions. -- --- For convenience, we include 'Elt' as a superclass. +-- Vector types behave analogously to tuple types. For testing equality +-- lane-wise on each element of a vector, see the class +-- 'Data.Array.Accelerate.VEq'. -- class Elt a => Eq a where (==) :: Exp a -> Exp a -> Exp Bool (/=) :: Exp a -> Exp a -> Exp Bool {-# MINIMAL (==) | (/=) #-} - x == y = mkLNot (x /= y) - x /= y = mkLNot (x == y) + x == y = not (x /= y) + x /= y = not (x == y) instance Eq () where - _ == _ = True_ - _ /= _ = False_ + _ == _ = True + _ /= _ = False instance Eq Z where - _ == _ = True_ - _ /= _ = False_ + _ == _ = True + _ /= _ = False -- Instances of 'Prelude.Eq' don't make sense with the standard signatures as -- the return type is fixed to 'Bool'. This instance is provided to provide @@ -127,8 +128,15 @@ instance P.Eq (Exp a) where (==) = preludeError "Eq.(==)" "(==)" (/=) = preludeError "Eq.(/=)" "(/=)" -preludeError :: String -> String -> a -preludeError x y = error (printf "Prelude.%s applied to EDSL types: use Data.Array.Accelerate.%s instead" x y) +preludeError :: HasCallStack => String -> String -> a +preludeError x y + = error + $ P.unlines [ printf "Prelude.%s applied to EDSL types: use Data.Array.Accelerate.%s instead" x y + , "" + , "These Prelude.Eq instances are present only to fulfil superclass" + , "constraints for subsequent classes in the standard Haskell numeric" + , "hierarchy." + ] runQ $ do let @@ -139,11 +147,13 @@ runQ $ do , ''Int16 , ''Int32 , ''Int64 + , ''Int128 , ''Word , ''Word8 , ''Word16 , ''Word32 , ''Word64 + , ''Word128 ] floatingTypes :: [Name] @@ -151,6 +161,7 @@ runQ $ do [ ''Half , ''Float , ''Double + , ''Float128 ] nonNumTypes :: [Name] @@ -158,28 +169,11 @@ runQ $ do [ ''Char ] - cTypes :: [Name] - cTypes = - [ ''CInt - , ''CUInt - , ''CLong - , ''CULong - , ''CLLong - , ''CULLong - , ''CShort - , ''CUShort - , ''CChar - , ''CUChar - , ''CSChar - , ''CFloat - , ''CDouble - ] - mkPrim :: Name -> Q [Dec] mkPrim t = [d| instance Eq $(conT t) where - (==) = mkEq - (/=) = mkNEq + (==) = mkPrimBinary $ PrimEq scalarType + (/=) = mkPrimBinary $ PrimNEq scalarType |] mkTup :: Int -> Q [Dec] @@ -199,19 +193,22 @@ runQ $ do is <- mapM mkPrim integralTypes fs <- mapM mkPrim floatingTypes ns <- mapM mkPrim nonNumTypes - cs <- mapM mkPrim cTypes ts <- mapM mkTup [2..16] - return $ concat (concat [is,fs,ns,cs,ts]) + return $ concat (concat [is,fs,ns,ts]) instance Eq sh => Eq (sh :. Int) where x == y = indexHead x == indexHead y && indexTail x == indexTail y x /= y = indexHead x /= indexHead y || indexTail x /= indexTail y instance Eq Bool where - x == y = mkCoerce x == (mkCoerce y :: Exp PrimBool) - x /= y = mkCoerce x /= (mkCoerce y :: Exp PrimBool) + (==) = mkPrimBinary $ PrimEq scalarType + (/=) = mkPrimBinary $ PrimNEq scalarType instance Eq Ordering where x == y = mkCoerce x == (mkCoerce y :: Exp TAG) x /= y = mkCoerce x /= (mkCoerce y :: Exp TAG) +instance VEq n a => Eq (Vec n a) where + x == y = vand (x ==* y) + x /= y = vor (x /=* y) + diff --git a/src/Data/Array/Accelerate/Classes/Floating.hs b/src/Data/Array/Accelerate/Classes/Floating.hs index b2f067af9..263fd9d1d 100644 --- a/src/Data/Array/Accelerate/Classes/Floating.hs +++ b/src/Data/Array/Accelerate/Classes/Floating.hs @@ -1,8 +1,13 @@ -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UnboxedTuples #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.Classes.Floating @@ -30,11 +35,16 @@ module Data.Array.Accelerate.Classes.Floating ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type +import qualified Data.Primitive.Vec as Prim import Data.Array.Accelerate.Classes.Fractional +import Language.Haskell.TH hiding ( Exp ) +import Prelude hiding ( Fractional, Floating ) import qualified Prelude as P @@ -42,104 +52,58 @@ import qualified Prelude as P -- type Floating a = (Fractional a, P.Floating (Exp a)) +runQ $ + let + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] -instance P.Floating (Exp Half) where - pi = mkPi - sin = mkSin - cos = mkCos - tan = mkTan - asin = mkAsin - acos = mkAcos - atan = mkAtan - sinh = mkSinh - cosh = mkCosh - tanh = mkTanh - asinh = mkAsinh - acosh = mkAcosh - atanh = mkAtanh - exp = mkExpFloating - sqrt = mkSqrt - log = mkLog - (**) = mkFPow - logBase = mkLogBase + thFloating :: Name -> Q [Dec] + thFloating a = + [d| instance P.Floating (Exp $(conT a)) where + pi = constant pi + sin = mkPrimUnary $ PrimSin floatingType + cos = mkPrimUnary $ PrimCos floatingType + tan = mkPrimUnary $ PrimTan floatingType + asin = mkPrimUnary $ PrimAsin floatingType + acos = mkPrimUnary $ PrimAcos floatingType + atan = mkPrimUnary $ PrimAtan floatingType + sinh = mkPrimUnary $ PrimSinh floatingType + cosh = mkPrimUnary $ PrimCosh floatingType + tanh = mkPrimUnary $ PrimTanh floatingType + asinh = mkPrimUnary $ PrimAsinh floatingType + acosh = mkPrimUnary $ PrimAcosh floatingType + atanh = mkPrimUnary $ PrimAtanh floatingType + exp = mkPrimUnary $ PrimExpFloating floatingType + sqrt = mkPrimUnary $ PrimSqrt floatingType + log = mkPrimUnary $ PrimLog floatingType + (**) = mkPrimBinary $ PrimFPow floatingType + logBase = mkPrimBinary $ PrimLogBase floatingType -instance P.Floating (Exp Float) where - pi = mkPi - sin = mkSin - cos = mkCos - tan = mkTan - asin = mkAsin - acos = mkAcos - atan = mkAtan - sinh = mkSinh - cosh = mkCosh - tanh = mkTanh - asinh = mkAsinh - acosh = mkAcosh - atanh = mkAtanh - exp = mkExpFloating - sqrt = mkSqrt - log = mkLog - (**) = mkFPow - logBase = mkLogBase - -instance P.Floating (Exp Double) where - pi = mkPi - sin = mkSin - cos = mkCos - tan = mkTan - asin = mkAsin - acos = mkAcos - atan = mkAtan - sinh = mkSinh - cosh = mkCosh - tanh = mkTanh - asinh = mkAsinh - acosh = mkAcosh - atanh = mkAtanh - exp = mkExpFloating - sqrt = mkSqrt - log = mkLog - (**) = mkFPow - logBase = mkLogBase - -instance P.Floating (Exp CFloat) where - pi = mkBitcast (mkPi @Float) - sin = mkSin - cos = mkCos - tan = mkTan - asin = mkAsin - acos = mkAcos - atan = mkAtan - sinh = mkSinh - cosh = mkCosh - tanh = mkTanh - asinh = mkAsinh - acosh = mkAcosh - atanh = mkAtanh - exp = mkExpFloating - sqrt = mkSqrt - log = mkLog - (**) = mkFPow - logBase = mkLogBase - -instance P.Floating (Exp CDouble) where - pi = mkBitcast (mkPi @Double) - sin = mkSin - cos = mkCos - tan = mkTan - asin = mkAsin - acos = mkAcos - atan = mkAtan - sinh = mkSinh - cosh = mkCosh - tanh = mkTanh - asinh = mkAsinh - acosh = mkAcosh - atanh = mkAtanh - exp = mkExpFloating - sqrt = mkSqrt - log = mkLog - (**) = mkFPow - logBase = mkLogBase + instance KnownNat n => P.Floating (Exp (Vec n $(conT a))) where + pi = constant (Vec (Prim.splat pi)) + sin = mkPrimUnary $ PrimSin floatingType + cos = mkPrimUnary $ PrimCos floatingType + tan = mkPrimUnary $ PrimTan floatingType + asin = mkPrimUnary $ PrimAsin floatingType + acos = mkPrimUnary $ PrimAcos floatingType + atan = mkPrimUnary $ PrimAtan floatingType + sinh = mkPrimUnary $ PrimSinh floatingType + cosh = mkPrimUnary $ PrimCosh floatingType + tanh = mkPrimUnary $ PrimTanh floatingType + asinh = mkPrimUnary $ PrimAsinh floatingType + acosh = mkPrimUnary $ PrimAcosh floatingType + atanh = mkPrimUnary $ PrimAtanh floatingType + exp = mkPrimUnary $ PrimExpFloating floatingType + sqrt = mkPrimUnary $ PrimSqrt floatingType + log = mkPrimUnary $ PrimLog floatingType + (**) = mkPrimBinary $ PrimFPow floatingType + logBase = mkPrimBinary $ PrimLogBase floatingType + |] + in + concat <$> mapM thFloating floatingTypes diff --git a/src/Data/Array/Accelerate/Classes/Fractional.hs b/src/Data/Array/Accelerate/Classes/Fractional.hs index 52bdc61e9..01eccea27 100644 --- a/src/Data/Array/Accelerate/Classes/Fractional.hs +++ b/src/Data/Array/Accelerate/Classes/Fractional.hs @@ -1,7 +1,13 @@ -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UnboxedTuples #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.Classes.Fractional @@ -20,12 +26,16 @@ module Data.Array.Accelerate.Classes.Fractional ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Num -import Prelude ( (.) ) +import Language.Haskell.TH hiding ( Exp ) +import Prelude hiding ( Num, Fractional ) +import qualified Data.Primitive.Vec as Prim import qualified Prelude as P @@ -44,29 +54,28 @@ import qualified Prelude as P -- type Fractional a = (Num a, P.Fractional (Exp a)) - -instance P.Fractional (Exp Half) where - (/) = mkFDiv - recip = mkRecip - fromRational = constant . P.fromRational - -instance P.Fractional (Exp Float) where - (/) = mkFDiv - recip = mkRecip - fromRational = constant . P.fromRational - -instance P.Fractional (Exp Double) where - (/) = mkFDiv - recip = mkRecip - fromRational = constant . P.fromRational - -instance P.Fractional (Exp CFloat) where - (/) = mkFDiv - recip = mkRecip - fromRational = constant . P.fromRational - -instance P.Fractional (Exp CDouble) where - (/) = mkFDiv - recip = mkRecip - fromRational = constant . P.fromRational +runQ $ + let + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + thFractional :: Name -> Q [Dec] + thFractional a = + [d| instance P.Fractional (Exp $(conT a)) where + (/) = mkPrimBinary $ PrimFDiv floatingType + recip = mkPrimUnary $ PrimRecip floatingType + fromRational = constant . P.fromRational + + instance KnownNat n => P.Fractional (Exp (Vec n $(conT a))) where + (/) = mkPrimBinary $ PrimFDiv floatingType + recip = mkPrimUnary $ PrimRecip floatingType + fromRational = constant . Vec . Prim.splat . P.fromRational + |] + in + concat <$> mapM thFractional floatingTypes diff --git a/src/Data/Array/Accelerate/Classes/FromBool.hs b/src/Data/Array/Accelerate/Classes/FromBool.hs new file mode 100644 index 000000000..23c5d994c --- /dev/null +++ b/src/Data/Array/Accelerate/Classes/FromBool.hs @@ -0,0 +1,69 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TemplateHaskell #-} +-- | +-- Module : Data.Array.Accelerate.Classes.FromBool +-- Copyright : [2016..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Classes.FromBool ( + + FromBool(..), + +) where + +import Data.Array.Accelerate.AST ( PrimFun(..) ) +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Vec +import Data.Array.Accelerate.Type + +import Data.Array.Accelerate.Classes.Integral + +import Language.Haskell.TH hiding ( Exp ) +import Prelude hiding ( Integral ) + + +-- | Convert from Bool to integral types +-- +-- @since 1.4.0.0 +-- +class FromBool a b where + fromBool :: Integral b => Exp a -> Exp b + + +runQ $ + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + thFromBool :: Name -> Q [Dec] + thFromBool b = + [d| instance FromBool Bool $(conT b) where + fromBool = mkPrimUnary $ PrimFromBool bitType integralType + + instance KnownNat n => FromBool (Vec n Bool) (Vec n $(conT b)) where + fromBool = mkPrimUnary $ PrimFromBool bitType integralType + |] + in + concat <$> mapM thFromBool integralTypes + + diff --git a/src/Data/Array/Accelerate/Classes/FromIntegral.hs b/src/Data/Array/Accelerate/Classes/FromIntegral.hs index d8678ee9d..d9dccf743 100644 --- a/src/Data/Array/Accelerate/Classes/FromIntegral.hs +++ b/src/Data/Array/Accelerate/Classes/FromIntegral.hs @@ -1,6 +1,6 @@ -{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE MonoLocalBinds #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE TemplateHaskell #-} -- | @@ -19,7 +19,10 @@ module Data.Array.Accelerate.Classes.FromIntegral ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Integral @@ -37,9 +40,8 @@ class FromIntegral a b where -- | General coercion from integral types fromIntegral :: Integral a => Exp a -> Exp b --- instance {-# OVERLAPPABLE #-} (Elt a, Elt b, IsIntegral a, IsNum b) => FromIntegral a b where --- fromIntegral = mkFromIntegral - +mkFromIntegral :: (IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b +mkFromIntegral = mkPrimUnary $ PrimFromIntegral integralType numType -- Reify in ghci: -- @@ -49,37 +51,78 @@ class FromIntegral a b where -- messages when we don't have an instance available, rather than a "can not -- deduce IsNum..." style error (which the user can do nothing about). -- -$(runQ $ do - let - -- Get all the types that our dictionaries reify - digItOut :: Name -> Q [Name] - digItOut name = do - TyConI (DataD _ _ _ _ cons _) <- reify name - let - -- This is what a constructor such as IntegralNumType will be reified - -- as prior to GHC 8.4... - dig (NormalC _ [(_, AppT (ConT n) (VarT _))]) = digItOut n - -- ...but this is what IntegralNumType will be reified as on GHC 8.4 - -- and later, after the changes described in - -- https://ghc.haskell.org/trac/ghc/wiki/Migration/8.4#TemplateHaskellreificationchangesforGADTs - dig (ForallC _ _ (GadtC _ [(_, AppT (ConT n) (VarT _))] _)) = digItOut n - dig (GadtC _ _ (AppT (ConT _) (ConT n))) = return [n] - dig _ = error "Unexpected case generating FromIntegral instances" - -- - concat `fmap` mapM dig cons - - thFromIntegral :: Name -> Name -> Q Dec - thFromIntegral a b = - let - ty = AppT (AppT (ConT (mkName "FromIntegral")) (ConT a)) (ConT b) - dec = ValD (VarP (mkName "fromIntegral")) (NormalB (VarE (mkName f))) [] - f | a == b = "id" - | otherwise = "mkFromIntegral" - in - instanceD (return []) (return ty) [return dec] - -- - as <- digItOut ''IntegralType - bs <- digItOut ''NumType - sequence [ thFromIntegral a b | a <- as, b <- bs ] - ) +-- > -- Get all the types that our dictionaries reify +-- > digItOut :: Name -> Q [Name] +-- > digItOut name = do +-- > TyConI (DataD _ _ _ _ cons _) <- reify name +-- > let +-- > -- This is what a constructor such as IntegralNumType will be reified +-- > -- as prior to GHC 8.4... +-- > dig (NormalC _ [(_, AppT (ConT n) (VarT _))]) = digItOut n +-- > -- ...but this is what IntegralNumType will be reified as on GHC 8.4 +-- > -- and later, after the changes described in +-- > -- https://ghc.haskell.org/trac/ghc/wiki/Migration/8.4#TemplateHaskellreificationchangesforGADTs +-- > dig (ForallC _ _ (GadtC _ [(_, AppT (ConT n) (VarT _))] _)) = digItOut n +-- > dig (GadtC _ _ (AppT (ConT _) (ConT n))) = return [n] +-- > dig _ = error "Unexpected case generating FromIntegral instances" +-- > -- +-- > concat `fmap` mapM dig cons +-- +runQ $ do + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + numTypes :: [Name] + numTypes = integralTypes ++ floatingTypes + + thFromIntegral :: Name -> Name -> Q [Dec] + thFromIntegral a b = + [d| instance FromIntegral $(conT a) $(conT b) where + fromIntegral = $(varE $ if a == b then 'id else 'mkFromIntegral ) + + instance KnownNat n => FromIntegral (Vec n $(conT a)) (Vec n $(conT b)) where + fromIntegral = $(varE $ if a == b then 'id else 'mkFromIntegral ) + |] + + thToBool :: Name -> Q [Dec] + thToBool a = + [d| +#if __GLASGOW_HASKELL__ >= 900 + -- | @since 1.4.0.0 +#endif + instance FromIntegral $(conT a) Bool where + fromIntegral = mkPrimUnary $ PrimToBool integralType bitType + +#if __GLASGOW_HASKELL__ >= 900 + -- | @since 1.4.0.0 +#endif + instance KnownNat n => FromIntegral (Vec n $(conT a)) (Vec n Bool) where + fromIntegral = mkPrimUnary $ PrimToBool integralType bitType + |] + -- + x <- concat <$> sequence [ thFromIntegral from to | from <- integralTypes, to <- numTypes ] + y <- concat <$> mapM thToBool integralTypes + return (x ++ y) diff --git a/src/Data/Array/Accelerate/Classes/Integral.hs b/src/Data/Array/Accelerate/Classes/Integral.hs index c6cadf7cc..faf61d320 100644 --- a/src/Data/Array/Accelerate/Classes/Integral.hs +++ b/src/Data/Array/Accelerate/Classes/Integral.hs @@ -3,6 +3,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MonoLocalBinds #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TemplateHaskell #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.Classes.Integral @@ -26,7 +27,11 @@ module Data.Array.Accelerate.Classes.Integral ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) +import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Vec +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Enum @@ -34,7 +39,9 @@ import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Classes.Real () -import Prelude ( error ) +import Control.Monad +import Language.Haskell.TH hiding ( Exp ) +import Prelude hiding ( Enum, Ord, Num, Integral ) import qualified Prelude as P @@ -42,166 +49,56 @@ import qualified Prelude as P -- type Integral a = (Enum a, Ord a, Num a, P.Integral (Exp a)) - -instance P.Integral (Exp Int) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Int8) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Int16) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Int32) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Int64) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Word) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Word8) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Word16) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Word32) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp Word64) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp CInt) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp CUInt) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp CLong) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp CULong) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp CLLong) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp CULLong) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp CShort) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" - -instance P.Integral (Exp CUShort) where - quot = mkQuot - rem = mkRem - div = mkIDiv - mod = mkMod - quotRem = mkQuotRem - divMod = mkDivMod - toInteger = error "Prelude.toInteger not supported for Accelerate types" +mkQuotRem :: IsIntegral (EltR t) => Exp t -> Exp t -> (Exp t, Exp t) +mkQuotRem (Exp x) (Exp y) = + let r = SmartExp $ PrimQuotRem integralType `PrimApp` SmartExp (Pair x y) + in ( mkExp $ Prj PairIdxLeft r + , mkExp $ Prj PairIdxRight r) + +mkDivMod :: IsIntegral (EltR t) => Exp t -> Exp t -> (Exp t, Exp t) +mkDivMod (Exp x) (Exp y) = + let r = SmartExp $ PrimDivMod integralType `PrimApp` SmartExp (Pair x y) + in ( mkExp $ Prj PairIdxLeft r + , mkExp $ Prj PairIdxRight r) + +runQ $ + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + mkIntegral :: Name -> Q [Dec] + mkIntegral a = + [d| instance P.Integral (Exp $(conT a)) where + quot = mkPrimBinary $ PrimQuot integralType + rem = mkPrimBinary $ PrimRem integralType + div = mkPrimBinary $ PrimIDiv integralType + mod = mkPrimBinary $ PrimMod integralType + quotRem = mkQuotRem + divMod = mkDivMod + toInteger = P.error "Prelude.toInteger not supported for Accelerate types" + + instance KnownNat n => P.Integral (Exp (Vec n $(conT a))) where + quot = mkPrimBinary $ PrimQuot integralType + rem = mkPrimBinary $ PrimRem integralType + div = mkPrimBinary $ PrimIDiv integralType + mod = mkPrimBinary $ PrimMod integralType + quotRem = mkQuotRem + divMod = mkDivMod + toInteger = P.error "Prelude.toInteger not supported for Accelerate types" + |] + in + concat <$> mapM mkIntegral integralTypes diff --git a/src/Data/Array/Accelerate/Classes/Num.hs b/src/Data/Array/Accelerate/Classes/Num.hs index 942b38673..351f050a0 100644 --- a/src/Data/Array/Accelerate/Classes/Num.hs +++ b/src/Data/Array/Accelerate/Classes/Num.hs @@ -1,7 +1,13 @@ -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UnboxedTuples #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.Classes.Num @@ -15,16 +21,20 @@ module Data.Array.Accelerate.Classes.Num ( - Num, + Num, Integer, (P.+), (P.-), (P.*), P.negate, P.abs, P.signum, P.fromInteger, ) where -import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type -import Prelude ( (.) ) +import Language.Haskell.TH hiding ( Exp ) +import Prelude hiding ( Num ) +import qualified Data.Primitive.Vec as Prim import qualified Prelude as P @@ -50,7 +60,6 @@ import qualified Prelude as P -- much, _much_ better. -- - -- | Conversion from an 'Integer'. -- -- An integer literal represents the application of the function 'fromInteger' @@ -61,215 +70,59 @@ import qualified Prelude as P -- fromInteger :: Num a => Integer -> Exp a -- fromInteger = P.fromInteger - -- | Basic numeric class -- type Num a = (Elt a, P.Num (Exp a)) +runQ $ + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + numTypes :: [Name] + numTypes = integralTypes ++ floatingTypes + + thNum :: Name -> Q [Dec] + thNum a = + [d| instance P.Num (Exp $(conT a)) where + (+) = mkPrimBinary $ PrimAdd numType + (-) = mkPrimBinary $ PrimSub numType + (*) = mkPrimBinary $ PrimMul numType + negate = mkPrimUnary $ PrimNeg numType + abs = mkPrimUnary $ PrimAbs numType + signum = mkPrimUnary $ PrimSig numType + fromInteger = constant . P.fromInteger + + instance KnownNat n => P.Num (Exp (Vec n $(conT a))) where + (+) = mkPrimBinary $ PrimAdd numType + (-) = mkPrimBinary $ PrimSub numType + (*) = mkPrimBinary $ PrimMul numType + negate = mkPrimUnary $ PrimNeg numType + abs = mkPrimUnary $ PrimAbs numType + signum = mkPrimUnary $ PrimSig numType + fromInteger = constant . Vec . Prim.splat . P.fromInteger + |] + in + concat <$> mapM thNum numTypes -instance P.Num (Exp Int) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Int8) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Int16) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Int32) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Int64) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Word) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Word8) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Word16) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Word32) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Word64) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CInt) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CUInt) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CLong) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CULong) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CLLong) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CULLong) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CShort) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CUShort) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Half) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Float) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp Double) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CFloat) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger - -instance P.Num (Exp CDouble) where - (+) = mkAdd - (-) = mkSub - (*) = mkMul - negate = mkNeg - abs = mkAbs - signum = mkSig - fromInteger = constant . P.fromInteger diff --git a/src/Data/Array/Accelerate/Classes/Ord.hs b/src/Data/Array/Accelerate/Classes/Ord.hs index fa80dd9d2..9aad3c2a5 100644 --- a/src/Data/Array/Accelerate/Classes/Ord.hs +++ b/src/Data/Array/Accelerate/Classes/Ord.hs @@ -1,9 +1,11 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RebindableSyntax #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} @@ -24,27 +26,28 @@ module Data.Array.Accelerate.Classes.Ord ( Ord(..), - Ordering(..), pattern LT_, pattern EQ_, pattern GT_, + Ordering, pattern LT, pattern EQ, pattern GT, ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Analysis.Match -import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Classes.Eq +import Data.Array.Accelerate.Classes.VEq +import Data.Array.Accelerate.Error import Data.Array.Accelerate.Pattern.Ordering +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Shape +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type - --- We must hide (==), as that operator is used for the literals 0, 1 and 2 in the pattern synonyms for Ordering. --- As RebindableSyntax is enabled, a literal pattern is compiled to a call to (==), meaning that the Prelude.(==) should be in scope as (==). -import Data.Array.Accelerate.Classes.Eq hiding ( (==) ) -import qualified Data.Array.Accelerate.Classes.Eq as A +import {-# SOURCE #-} Data.Array.Accelerate.Classes.VOrd import Data.Char import Language.Haskell.TH.Extra hiding ( Exp ) -import Prelude ( ($), (>>=), Ordering(..), Num(..), Maybe(..), String, show, error, unlines, return, concat, map, mapM ) +import Prelude ( ($), Num(..), Maybe(..), String, show, error, unlines, return, concat, map, mapM ) import Text.Printf import qualified Prelude as P @@ -54,7 +57,7 @@ infix 4 > infix 4 <= infix 4 >= --- | The 'Ord' class for totally ordered datatypes +-- | The 'Ord' class for totally ordered data types -- class Eq a => Ord a where {-# MINIMAL (<=) | compare #-} @@ -66,41 +69,23 @@ class Eq a => Ord a where max :: Exp a -> Exp a -> Exp a compare :: Exp a -> Exp a -> Exp Ordering - x < y = if compare x y A.== constant LT then constant True else constant False - x <= y = if compare x y A.== constant GT then constant False else constant True - x > y = if compare x y A.== constant GT then constant True else constant False - x >= y = if compare x y A.== constant LT then constant False else constant True + x < y = cond (compare x y == LT) True False + x <= y = cond (compare x y == GT) False True + x > y = cond (compare x y == GT) True False + x >= y = cond (compare x y == LT) False True - min x y = if x <= y then x else y - max x y = if x <= y then y else x + min x y = cond (x <= y) x y + max x y = cond (x <= y) y x - compare x y = - if x A.== y then constant EQ else - if x <= y then constant LT - else constant GT + compare x y + = cond (x == y) EQ + $ cond (x <= y) LT + {- else -} GT --- Local redefinition for use with RebindableSyntax (pulled forward from Prelude.hs) +-- Local redefinition to prevent cyclic imports -- -ifThenElse :: Elt a => Exp Bool -> Exp a -> Exp a -> Exp a -ifThenElse (Exp c) (Exp x) (Exp y) = Exp $ SmartExp $ Cond (mkCoerce' c) x y - -instance Ord () where - (<) _ _ = constant False - (>) _ _ = constant False - (>=) _ _ = constant True - (<=) _ _ = constant True - min _ _ = constant () - max _ _ = constant () - compare _ _ = constant EQ - -instance Ord Z where - (<) _ _ = constant False - (>) _ _ = constant False - (<=) _ _ = constant True - (>=) _ _ = constant True - min _ _ = constant Z - max _ _ = constant Z - +cond :: Elt a => Exp Bool -> Exp a -> Exp a -> Exp a +cond (Exp c) (Exp x) (Exp y) = Exp $ SmartExp $ Cond (mkCoerce' c) x y -- Instances of 'Prelude.Ord' (mostly) don't make sense with the standard -- signatures as the return type is fixed to 'Bool'. This instance is provided @@ -118,7 +103,7 @@ instance Ord a => P.Ord (Exp a) where min = min max = max -preludeError :: String -> String -> a +preludeError :: HasCallStack => String -> String -> a preludeError x y = error $ unlines [ printf "Prelude.%s applied to EDSL types: use Data.Array.Accelerate.%s instead" x y @@ -137,11 +122,13 @@ runQ $ do , ''Int16 , ''Int32 , ''Int64 + , ''Int128 , ''Word , ''Word8 , ''Word16 , ''Word32 , ''Word64 + , ''Word128 ] floatingTypes :: [Name] @@ -149,6 +136,7 @@ runQ $ do [ ''Half , ''Float , ''Double + , ''Float128 ] nonNumTypes :: [Name] @@ -156,53 +144,36 @@ runQ $ do [ ''Char ] - cTypes :: [Name] - cTypes = - [ ''CInt - , ''CUInt - , ''CLong - , ''CULong - , ''CLLong - , ''CULLong - , ''CShort - , ''CUShort - , ''CChar - , ''CUChar - , ''CSChar - , ''CFloat - , ''CDouble - ] - mkPrim :: Name -> Q [Dec] mkPrim t = [d| instance Ord $(conT t) where - (<) = mkLt - (>) = mkGt - (<=) = mkLtEq - (>=) = mkGtEq - min = mkMin - max = mkMax + (<) = mkPrimBinary $ PrimLt scalarType + (>) = mkPrimBinary $ PrimGt scalarType + (<=) = mkPrimBinary $ PrimLtEq scalarType + (>=) = mkPrimBinary $ PrimGtEq scalarType + min = mkPrimBinary $ PrimMin scalarType + max = mkPrimBinary $ PrimMax scalarType |] - mkLt' :: [ExpQ] -> [ExpQ] -> ExpQ - mkLt' [x] [y] = [| $x < $y |] - mkLt' (x:xs) (y:ys) = [| $x < $y || ( $x A.== $y && $(mkLt' xs ys) ) |] - mkLt' _ _ = error "mkLt'" + mkLt :: [ExpQ] -> [ExpQ] -> ExpQ + mkLt [x] [y] = [| $x < $y |] + mkLt (x:xs) (y:ys) = [| $x < $y || ( $x == $y && $(mkLt xs ys) ) |] + mkLt _ _ = error "mkLt" - mkGt' :: [ExpQ] -> [ExpQ] -> ExpQ - mkGt' [x] [y] = [| $x > $y |] - mkGt' (x:xs) (y:ys) = [| $x > $y || ( $x A.== $y && $(mkGt' xs ys) ) |] - mkGt' _ _ = error "mkGt'" + mkGt :: [ExpQ] -> [ExpQ] -> ExpQ + mkGt [x] [y] = [| $x > $y |] + mkGt (x:xs) (y:ys) = [| $x > $y || ( $x == $y && $(mkGt xs ys) ) |] + mkGt _ _ = error "mkGt" - mkLtEq' :: [ExpQ] -> [ExpQ] -> ExpQ - mkLtEq' [x] [y] = [| $x < $y |] - mkLtEq' (x:xs) (y:ys) = [| $x < $y || ( $x A.== $y && $(mkLtEq' xs ys) ) |] - mkLtEq' _ _ = error "mkLtEq'" + mkLtEq :: [ExpQ] -> [ExpQ] -> ExpQ + mkLtEq [x] [y] = [| $x <= $y |] + mkLtEq (x:xs) (y:ys) = [| $x < $y || ( $x == $y && $(mkLtEq xs ys) ) |] + mkLtEq _ _ = error "mkLtEq" - mkGtEq' :: [ExpQ] -> [ExpQ] -> ExpQ - mkGtEq' [x] [y] = [| $x > $y |] - mkGtEq' (x:xs) (y:ys) = [| $x > $y || ( $x A.== $y && $(mkGtEq' xs ys) ) |] - mkGtEq' _ _ = error "mkGtEq'" + mkGtEq :: [ExpQ] -> [ExpQ] -> ExpQ + mkGtEq [x] [y] = [| $x >= $y |] + mkGtEq (x:xs) (y:ys) = [| $x > $y || ( $x == $y && $(mkGtEq xs ys) ) |] + mkGtEq _ _ = error "mkGtEq" mkTup :: Int -> Q [Dec] mkTup n = @@ -214,18 +185,36 @@ runQ $ do pat vs = conP (mkName ('T':show n)) (map varP vs) in [d| instance $cst => Ord $res where - $(pat xs) < $(pat ys) = $( mkLt' (map varE xs) (map varE ys) ) - $(pat xs) > $(pat ys) = $( mkGt' (map varE xs) (map varE ys) ) - $(pat xs) >= $(pat ys) = $( mkGtEq' (map varE xs) (map varE ys) ) - $(pat xs) <= $(pat ys) = $( mkLtEq' (map varE xs) (map varE ys) ) + $(pat xs) < $(pat ys) = $( mkLt (map varE xs) (map varE ys) ) + $(pat xs) > $(pat ys) = $( mkGt (map varE xs) (map varE ys) ) + $(pat xs) >= $(pat ys) = $( mkGtEq (map varE xs) (map varE ys) ) + $(pat xs) <= $(pat ys) = $( mkLtEq (map varE xs) (map varE ys) ) |] is <- mapM mkPrim integralTypes fs <- mapM mkPrim floatingTypes ns <- mapM mkPrim nonNumTypes - cs <- mapM mkPrim cTypes ts <- mapM mkTup [2..16] - return $ concat (concat [is,fs,ns,cs,ts]) + return $ concat (concat [is,fs,ns,ts]) + + +instance Ord () where + (<) _ _ = constant False + (>) _ _ = constant False + (>=) _ _ = constant True + (<=) _ _ = constant True + min _ _ = constant () + max _ _ = constant () + compare _ _ = constant EQ + +instance Ord Z where + (<) _ _ = constant False + (>) _ _ = constant False + (<=) _ _ = constant True + (>=) _ _ = constant True + min _ _ = constant Z + max _ _ = constant Z + compare _ _ = constant EQ instance Ord sh => Ord (sh :. Int) where x <= y = indexHead x <= indexHead y && indexTail x <= indexTail y @@ -247,3 +236,19 @@ instance Ord Ordering where min x y = mkCoerce $ min (mkCoerce x) (mkCoerce y :: Exp TAG) max x y = mkCoerce $ max (mkCoerce x) (mkCoerce y :: Exp TAG) +instance VOrd n a => Ord (Vec n a) where + (<) = vcmp (<*) + (>) = vcmp (>*) + (<=) = vcmp (<=*) + (>=) = vcmp (>=*) + +vcmp :: forall n a. VOrd n a + => (Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool)) + -> (Exp (Vec n a) -> Exp (Vec n a) -> Exp Bool) +vcmp cmp x y = + let go [u] [_] = u + go (u:us) (v:vs) = u || (v && go us vs) + go _ _ = internalError "unexpected vector encoding" + in + go (unpack (cmp x y)) (unpack (x ==* y)) + diff --git a/src/Data/Array/Accelerate/Classes/Ord.hs-boot b/src/Data/Array/Accelerate/Classes/Ord.hs-boot new file mode 100644 index 000000000..5a1354c38 --- /dev/null +++ b/src/Data/Array/Accelerate/Classes/Ord.hs-boot @@ -0,0 +1,43 @@ +{-# LANGUAGE NoImplicitPrelude #-} +-- | +-- Module : Data.Array.Accelerate.Classes.VOrd +-- Copyright : [2016..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Classes.Ord ( + + Ord(..), + Ordering, + +) where + +import Data.Array.Accelerate.Classes.Eq +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Pattern.Ordering + + +class Eq a => Ord a where + {-# MINIMAL (<=) | compare #-} + (<) :: Exp a -> Exp a -> Exp Bool + (>) :: Exp a -> Exp a -> Exp Bool + (<=) :: Exp a -> Exp a -> Exp Bool + (>=) :: Exp a -> Exp a -> Exp Bool + min :: Exp a -> Exp a -> Exp a + max :: Exp a -> Exp a -> Exp a + compare :: Exp a -> Exp a -> Exp Ordering + + (<) = undefined + (<=) = undefined + (>) = undefined + (>=) = undefined + + min = undefined + max = undefined + + compare = undefined + diff --git a/src/Data/Array/Accelerate/Classes/Rational.hs b/src/Data/Array/Accelerate/Classes/Rational.hs index 7ec238459..bb5b05cc2 100644 --- a/src/Data/Array/Accelerate/Classes/Rational.hs +++ b/src/Data/Array/Accelerate/Classes/Rational.hs @@ -1,4 +1,6 @@ -{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} -- | -- Module : Data.Array.Accelerate.Classes.Rational -- Copyright : [2016..2020] The Accelerate Team @@ -19,7 +21,7 @@ import Data.Array.Accelerate.Data.Ratio import Data.Array.Accelerate.Data.Bits import Data.Array.Accelerate.Language -import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Type @@ -29,7 +31,9 @@ import Data.Array.Accelerate.Classes.Integral import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Classes.RealFloat +import Data.Array.Accelerate.Classes.RealFrac +import Data.Kind import Prelude ( ($) ) @@ -40,57 +44,67 @@ import Prelude ( ($) ) -- package. -- class (Num a, Ord a) => Rational a where + type Embedding a :: Type + -- | Convert a number to the quotient of two integers -- - toRational :: (FromIntegral Int64 b, Integral b) => Exp a -> Exp (Ratio b) - -instance Rational Int where toRational = integralToRational -instance Rational Int8 where toRational = integralToRational -instance Rational Int16 where toRational = integralToRational -instance Rational Int32 where toRational = integralToRational -instance Rational Int64 where toRational = integralToRational -instance Rational Word where toRational = integralToRational -instance Rational Word8 where toRational = integralToRational -instance Rational Word16 where toRational = integralToRational -instance Rational Word32 where toRational = integralToRational -instance Rational Word64 where toRational = integralToRational - -instance Rational Half where toRational = floatingToRational -instance Rational Float where toRational = floatingToRational -instance Rational Double where toRational = floatingToRational - + toRational :: (FromIntegral (Embedding a) b, Integral b) => Exp a -> Exp (Ratio b) + +instance Rational Int where type Embedding Int = Int; toRational = integralToRational +instance Rational Int8 where type Embedding Int8 = Int8; toRational = integralToRational +instance Rational Int16 where type Embedding Int16 = Int16; toRational = integralToRational +instance Rational Int32 where type Embedding Int32 = Int32; toRational = integralToRational +instance Rational Int64 where type Embedding Int64 = Int64; toRational = integralToRational +instance Rational Int128 where type Embedding Int128 = Int128; toRational = integralToRational +instance Rational Word where type Embedding Word = Word; toRational = integralToRational +instance Rational Word8 where type Embedding Word8 = Word8; toRational = integralToRational +instance Rational Word16 where type Embedding Word16 = Word16; toRational = integralToRational +instance Rational Word32 where type Embedding Word32 = Word32; toRational = integralToRational +instance Rational Word64 where type Embedding Word64 = Word64; toRational = integralToRational +instance Rational Word128 where type Embedding Word128 = Word128; toRational = integralToRational + +instance Rational Half where type Embedding Half = Int16; toRational = floatingToRational +instance Rational Float where type Embedding Float = Int32; toRational = floatingToRational +instance Rational Double where type Embedding Double = Int64; toRational = floatingToRational +instance Rational Float128 where type Embedding Float128 = Int128; toRational = floatingToRational integralToRational - :: (Integral a, Integral b, FromIntegral a Int64, FromIntegral Int64 b) + :: (Integral a, Integral b, FromIntegral a b) => Exp a -> Exp (Ratio b) -integralToRational x = fromIntegral (fromIntegral x :: Exp Int64) :% 1 +integralToRational x = fromIntegral x :% 1 floatingToRational - :: (RealFloat a, Integral b, FromIntegral Int64 b) + :: (RealFloat a, Integral b, FromIntegral (Significand a) b, FromIntegral Int (Significand a), FiniteBits (Significand a), FromIntegral (Exponent a) (Significand a)) => Exp a -> Exp (Ratio b) floatingToRational x = fromIntegral u :% fromIntegral v where - (m, e) = decodeFloat x - (n, d) = elimZeros m (negate e) - u :% v = cond (e >= 0) ((m `shiftL` e) :% 1) $ - cond (m .&. 1 == 0) (n :% shiftL 1 d) $ - (m :% shiftL 1 (negate e)) + T2 m e = decodeFloat x + T2 n d = elimZeros m ne' + ne' = negate e' + e' = fromIntegral e + u :% v = cond (e' >= 0) ((m `shiftL` e') :% 1) $ + cond (m .&. 1 == 0) (n :% shiftL 1 d) $ + (m :% shiftL 1 ne') -- Stolen from GHC.Float.ConversionUtils -- Double mantissa have 53 bits, which fits in an Int64 -- -elimZeros :: Exp Int64 -> Exp Int -> (Exp Int64, Exp Int) -elimZeros x y = (u, v) +elimZeros + :: forall e. (Num e, Ord e, FiniteBits e) + => Exp e + -> Exp e + -> Exp (e, e) +elimZeros x y = T2 u v where T3 _ u v = while (\(T3 p _ _) -> p) elim (T3 moar x y) kthxbai = constant False moar = constant True - elim :: Exp (Bool, Int64, Int) -> Exp (Bool, Int64, Int) + elim :: Exp (Bool, e, e) -> Exp (Bool, e, e) elim (T3 _ n e) = - let t = countTrailingZeros (fromIntegral n :: Exp Word8) + let t = countTrailingZeros n in cond (e <= t) (T3 kthxbai (shiftR n e) 0) $ cond (t < 8) (T3 kthxbai (shiftR n t) (e-t)) $ diff --git a/src/Data/Array/Accelerate/Classes/Real.hs b/src/Data/Array/Accelerate/Classes/Real.hs index a6bd1b185..aa38158cf 100644 --- a/src/Data/Array/Accelerate/Classes/Real.hs +++ b/src/Data/Array/Accelerate/Classes/Real.hs @@ -34,8 +34,8 @@ type Real a = (Num a, Ord a, P.Real (Exp a)) -- Instances of 'Real' don't make sense in Accelerate at the moment. These are -- only provided to fulfil superclass constraints; e.g. Integral. -- --- We won't need `toRational' until we support rational numbers in AP --- computations. +-- We won't need `toRational' until we support rational numbers in scalar +-- expressions. -- instance (Num a, Ord a) => P.Real (Exp a) where toRational diff --git a/src/Data/Array/Accelerate/Classes/RealFloat.hs b/src/Data/Array/Accelerate/Classes/RealFloat.hs index 0b3366ec2..a24d55c9f 100644 --- a/src/Data/Array/Accelerate/Classes/RealFloat.hs +++ b/src/Data/Array/Accelerate/Classes/RealFloat.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} @@ -7,8 +7,10 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fno-warn-orphans #-} +{-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Classes.RealFloat -- Copyright : [2016..2020] The Accelerate Team @@ -22,13 +24,16 @@ module Data.Array.Accelerate.Classes.RealFloat ( RealFloat(..), + defaultProperFraction, ) where -import Data.Array.Accelerate.Error -import Data.Array.Accelerate.Language ( cond, while ) -import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.AST ( PrimFun(..), BitOrMask, PrimMask ) +import Data.Array.Accelerate.Language ( (^), cond, while ) +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Elt ( Elt(..) ) +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type import Data.Array.Accelerate.Data.Bits @@ -36,147 +41,254 @@ import Data.Array.Accelerate.Data.Bits import Data.Array.Accelerate.Classes.Eq import Data.Array.Accelerate.Classes.Floating import Data.Array.Accelerate.Classes.FromIntegral +import Data.Array.Accelerate.Classes.Integral import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Classes.RealFrac +import Data.Array.Accelerate.Classes.VEq +import Data.Array.Accelerate.Classes.VOrd -import Data.Text.Lazy.Builder -import Formatting +import Data.Coerce +import Data.Kind +import Data.Type.Equality import Text.Printf -import Prelude ( (.), ($), String, error, undefined, unlines, otherwise ) +import Prelude ( (.), ($), String, error, undefined, unlines ) import qualified Prelude as P -- | Efficient, machine-independent access to the components of a floating-point -- number -- -class (RealFrac a, Floating a) => RealFloat a where +class (RealFrac a, Floating a, Integral (Exponent a)) => RealFloat a where + type Exponent a :: Type + -- | The radix of the representation (often 2) (constant) - floatRadix :: Exp a -> Exp Int64 -- Integer - default floatRadix :: P.RealFloat a => Exp a -> Exp Int64 - floatRadix _ = P.fromInteger (P.floatRadix (undefined::a)) + floatRadix :: Exp a -> Exp Int -- | The number of digits of 'floatRadix' in the significand (constant) - floatDigits :: Exp a -> Exp Int - default floatDigits :: P.RealFloat a => Exp a -> Exp Int - floatDigits _ = constant (P.floatDigits (undefined::a)) + floatDigits :: Exp a -> Exp Int -- | The lowest and highest values the exponent may assume (constant) - floatRange :: Exp a -> (Exp Int, Exp Int) - default floatRange :: P.RealFloat a => Exp a -> (Exp Int, Exp Int) - floatRange _ = let (m,n) = P.floatRange (undefined::a) - in (constant m, constant n) + floatRange :: Exp a -> Exp (Int, Int) -- | Return the significand and an appropriately scaled exponent. If -- @(m,n) = 'decodeFloat' x@ then @x = m*b^^n@, where @b@ is the -- floating-point radix ('floatRadix'). Furthermore, either @m@ and @n@ are -- both zero, or @b^(d-1) <= 'abs' m < b^d@, where @d = 'floatDigits' x@. - decodeFloat :: Exp a -> (Exp Int64, Exp Int) -- Integer + decodeFloat :: Exp a -> Exp (Significand a, Exponent a) -- | Inverse of 'decodeFloat' - encodeFloat :: Exp Int64 -> Exp Int -> Exp a -- Integer - default encodeFloat :: (FromIntegral Int a, FromIntegral Int64 a) => Exp Int64 -> Exp Int -> Exp a - encodeFloat x e = fromIntegral x * (fromIntegral (floatRadix (undefined :: Exp a)) ** fromIntegral e) + encodeFloat :: Exp (Significand a) -> Exp (Exponent a) -> Exp a -- | Corresponds to the second component of 'decodeFloat' - exponent :: Exp a -> Exp Int - exponent x = let (m,n) = decodeFloat x - in cond (m == 0) - 0 - (n + floatDigits x) + exponent :: Exp a -> Exp (Exponent a) - -- | Corresponds to the first component of 'decodeFloat' - significand :: Exp a -> Exp a - significand x = let (m,_) = decodeFloat x - in encodeFloat m (negate (floatDigits x)) + -- | The first component of 'decodeFloat', scaled to lie in the open interval (-1,1). + significand :: Exp a -> Exp a -- | Multiply a floating point number by an integer power of the radix - scaleFloat :: Exp Int -> Exp a -> Exp a - scaleFloat k x = - cond (k == 0 || isFix) x - $ encodeFloat m (n + clamp b) - where - isFix = x == 0 || isNaN x || isInfinite x - (m,n) = decodeFloat x - (l,h) = floatRange x - d = floatDigits x - b = h - l + 4*d - -- n+k may overflow, which would lead to incorrect results, hence we clamp - -- the scaling parameter. If (n+k) would be larger than h, (n + clamp b k) - -- must be too, similar for smaller than (l-d). - clamp bd = max (-bd) (min bd k) + scaleFloat :: Exp Int -> Exp a -> Exp a -- | 'True' if the argument is an IEEE \"not-a-number\" (NaN) value - isNaN :: Exp a -> Exp Bool + isNaN :: (BitOrMask (EltR a) ~ EltR b) => Exp a -> Exp b -- | 'True' if the argument is an IEEE infinity or negative-infinity - isInfinite :: Exp a -> Exp Bool + isInfinite :: (BitOrMask (EltR a) ~ EltR b) => Exp a -> Exp b - -- | 'True' if the argument is too small to be represented in normalized - -- format - isDenormalized :: Exp a -> Exp Bool + -- | 'True' if the argument is too small to be represented in normalized format + isDenormalized :: (BitOrMask (EltR a) ~ EltR b) => Exp a -> Exp b -- | 'True' if the argument is an IEEE negative zero - isNegativeZero :: Exp a -> Exp Bool + isNegativeZero :: (BitOrMask (EltR a) ~ EltR b) => Exp a -> Exp b -- | 'True' if the argument is an IEEE floating point number - isIEEE :: Exp a -> Exp Bool - default isIEEE :: P.RealFloat a => Exp a -> Exp Bool - isIEEE _ = constant (P.isIEEE (undefined::a)) + isIEEE :: Exp a -> Exp Bool -- | A version of arctangent taking two real floating-point arguments. -- For real floating @x@ and @y@, @'atan2' y x@ computes the angle (from the -- positive x-axis) of the vector from the origin to the point @(x,y)@. -- @'atan2' y x@ returns a value in the range [@-pi@, @pi@]. - atan2 :: Exp a -> Exp a -> Exp a - - -instance RealFloat Half where + atan2 :: Exp a -> Exp a -> Exp a + + +instance RealFrac Float16 where + type Significand Float16 = Int16 + properFraction = defaultProperFraction + +instance RealFrac Float32 where + type Significand Float32 = Int32 + properFraction = defaultProperFraction + +instance RealFrac Float64 where + type Significand Float64 = Int64 + properFraction = defaultProperFraction + +instance RealFrac Float128 where + type Significand Float128 = Int128 + properFraction = defaultProperFraction + +instance KnownNat n => RealFrac (Vec n Float16) where + type Significand (Vec n Float16) = Vec n Int16 + properFraction = defaultProperFraction' + +instance KnownNat n => RealFrac (Vec n Float32) where + type Significand (Vec n Float32) = Vec n Int32 + properFraction = defaultProperFraction' + +instance KnownNat n => RealFrac (Vec n Float64) where + type Significand (Vec n Float64) = Vec n Int64 + properFraction = defaultProperFraction' + +instance KnownNat n => RealFrac (Vec n Float128) where + type Significand (Vec n Float128) = Vec n Int128 + properFraction = defaultProperFraction' + +instance RealFloat Float16 where + type Exponent Float16 = Int + floatRadix = defaultFloatRadix + floatDigits = defaultFloatDigits + floatRange = defaultFloatRange + encodeFloat s = mkCoerce . defaultEncodeFloat @1 @Float16 (mkBitcast s) . mkBitcast + exponent = mkCoerce . defaultExponent . mkBitcast @(Vec 1 Float16) + significand = mkCoerce . defaultSignificand . mkBitcast @(Vec 1 Float16) + scaleFloat k = mkCoerce . defaultScaleFloat k . mkBitcast @(Vec 1 Float16) atan2 = mkAtan2 isNaN = mkIsNaN isInfinite = mkIsInfinite - isDenormalized = ieee754 "isDenormalized" (ieee754_f16_is_denormalized . mkBitcast) - isNegativeZero = ieee754 "isNegativeZero" (ieee754_f16_is_negative_zero . mkBitcast) - decodeFloat = ieee754 "decodeFloat" (\x -> let T2 m n = ieee754_f16_decode (mkBitcast x) - in (fromIntegral m, n)) - -instance RealFloat Float where + isIEEE = defaultIsIEEE + isDenormalized = mkCoerce . ieee754_f16_is_denormalized . mkBitcast @(Vec 1 Word16) + isNegativeZero = mkCoerce . ieee754_f16_is_negative_zero . mkBitcast @(Vec 1 Word16) + decodeFloat = mkCoerce . ieee754_f16_decode . mkBitcast @(Vec 1 Word16) + +instance RealFloat Float32 where + type Exponent Float32 = Int + floatRadix = defaultFloatRadix + floatDigits = defaultFloatDigits + floatRange = defaultFloatRange + encodeFloat s = mkCoerce . defaultEncodeFloat @1 @Float32 (mkBitcast s) . mkBitcast + exponent = mkCoerce . defaultExponent . mkBitcast @(Vec 1 Float32) + significand = mkCoerce . defaultSignificand . mkBitcast @(Vec 1 Float32) + scaleFloat k = mkCoerce . defaultScaleFloat k . mkBitcast @(Vec 1 Float32) atan2 = mkAtan2 isNaN = mkIsNaN isInfinite = mkIsInfinite - isDenormalized = ieee754 "isDenormalized" (ieee754_f32_is_denormalized . mkBitcast) - isNegativeZero = ieee754 "isNegativeZero" (ieee754_f32_is_negative_zero . mkBitcast) - decodeFloat = ieee754 "decodeFloat" (\x -> let T2 m n = ieee754_f32_decode (mkBitcast x) - in (fromIntegral m, n)) - -instance RealFloat Double where + isIEEE = defaultIsIEEE + isDenormalized = mkCoerce . ieee754_f32_is_denormalized . mkBitcast @(Vec 1 Word32) + isNegativeZero = mkCoerce . ieee754_f32_is_negative_zero . mkBitcast @(Vec 1 Word32) + decodeFloat = mkCoerce . ieee754_f32_decode . mkBitcast @(Vec 1 Word32) + +instance RealFloat Float64 where + type Exponent Float64 = Int + floatRadix = defaultFloatRadix + floatDigits = defaultFloatDigits + floatRange = defaultFloatRange + encodeFloat s = mkCoerce . defaultEncodeFloat @1 @Float64 (mkBitcast s) . mkBitcast + exponent = mkCoerce . defaultExponent . mkBitcast @(Vec 1 Float64) + significand = mkCoerce . defaultSignificand . mkBitcast @(Vec 1 Float64) + scaleFloat k = mkCoerce . defaultScaleFloat k . mkBitcast @(Vec 1 Float64) atan2 = mkAtan2 isNaN = mkIsNaN isInfinite = mkIsInfinite - isDenormalized = ieee754 "isDenormalized" (ieee754_f64_is_denormalized . mkBitcast) - isNegativeZero = ieee754 "isNegativeZero" (ieee754_f64_is_negative_zero . mkBitcast) - decodeFloat = ieee754 "decodeFloat" (\x -> let T2 m n = ieee754_f64_decode (mkBitcast x) - in (m, n)) + isIEEE = defaultIsIEEE + isDenormalized = mkCoerce . ieee754_f64_is_denormalized . mkBitcast @(Vec 1 Word64) + isNegativeZero = mkCoerce . ieee754_f64_is_negative_zero . mkBitcast @(Vec 1 Word64) + decodeFloat = mkCoerce . ieee754_f64_decode . mkBitcast @(Vec 1 Word64) + +instance RealFloat Float128 where + type Exponent Float128 = Int + floatRadix = defaultFloatRadix + floatDigits = defaultFloatDigits + floatRange = defaultFloatRange + encodeFloat s = mkCoerce . defaultEncodeFloat @1 @Float128 (mkBitcast s) . mkBitcast + exponent = mkCoerce . defaultExponent . mkBitcast @(Vec 1 Float128) + significand = mkCoerce . defaultSignificand . mkBitcast @(Vec 1 Float128) + scaleFloat k = mkCoerce . defaultScaleFloat k . mkBitcast @(Vec 1 Float128) + atan2 = mkAtan2 + isNaN = mkIsNaN + isInfinite = mkIsInfinite + isIEEE = defaultIsIEEE + isDenormalized = mkCoerce . ieee754_f128_is_denormalized . mkBitcast @(Vec 1 Word128) + isNegativeZero = mkCoerce . ieee754_f128_is_negative_zero . mkBitcast @(Vec 1 Word128) + decodeFloat = mkCoerce . ieee754_f128_decode . mkBitcast @(Vec 1 Word128) + +instance KnownNat n => RealFloat (Vec n Float16) where + type Exponent (Vec n Float16) = Vec n Int + floatRadix _ = defaultFloatRadix (undefined :: Exp Float16) + floatDigits _ = defaultFloatDigits (undefined :: Exp Float16) + floatRange _ = defaultFloatRange (undefined :: Exp Float16) + decodeFloat = ieee754_f16_decode . mkBitcast' + encodeFloat = defaultEncodeFloat + exponent = defaultExponent + significand = defaultSignificand + scaleFloat = defaultScaleFloat + isNaN = mkIsNaN + isInfinite = mkIsInfinite + isDenormalized = coerce . ieee754_f16_is_denormalized . mkBitcast' + isNegativeZero = coerce . ieee754_f16_is_negative_zero . mkBitcast' + isIEEE _ = defaultIsIEEE (undefined :: Exp Float16) + atan2 = mkAtan2 + -instance RealFloat CFloat where +instance KnownNat n => RealFloat (Vec n Float32) where + type Exponent (Vec n Float32) = Vec n Int + floatRadix _ = defaultFloatRadix (undefined :: Exp Float32) + floatDigits _ = defaultFloatDigits (undefined :: Exp Float32) + floatRange _ = defaultFloatRange (undefined :: Exp Float32) + decodeFloat = ieee754_f32_decode . mkBitcast' + encodeFloat = defaultEncodeFloat + exponent = defaultExponent + significand = defaultSignificand + scaleFloat = defaultScaleFloat + isNaN = mkIsNaN + isInfinite = mkIsInfinite + isDenormalized = coerce . ieee754_f32_is_denormalized . mkBitcast' + isNegativeZero = coerce . ieee754_f32_is_negative_zero . mkBitcast' + isIEEE _ = defaultIsIEEE (undefined :: Exp Float32) atan2 = mkAtan2 - isNaN = mkIsNaN . mkBitcast @Float - isInfinite = mkIsInfinite . mkBitcast @Float - isDenormalized = ieee754 "isDenormalized" (ieee754_f32_is_denormalized . mkBitcast) - isNegativeZero = ieee754 "isNegativeZero" (ieee754_f32_is_negative_zero . mkBitcast) - decodeFloat = ieee754 "decodeFloat" (\x -> let T2 m n = ieee754_f32_decode (mkBitcast x) - in (fromIntegral m, n)) - encodeFloat x e = mkBitcast (encodeFloat @Float x e) - -instance RealFloat CDouble where + +instance KnownNat n => RealFloat (Vec n Float64) where + type Exponent (Vec n Float64) = Vec n Int + floatRadix _ = defaultFloatRadix (undefined :: Exp Float64) + floatDigits _ = defaultFloatDigits (undefined :: Exp Float64) + floatRange _ = defaultFloatRange (undefined :: Exp Float64) + decodeFloat = ieee754_f64_decode . mkBitcast' + encodeFloat = defaultEncodeFloat + exponent = defaultExponent + significand = defaultSignificand + scaleFloat = defaultScaleFloat + isNaN = mkIsNaN + isInfinite = mkIsInfinite + isDenormalized = coerce . ieee754_f64_is_denormalized . mkBitcast' + isNegativeZero = coerce . ieee754_f64_is_negative_zero . mkBitcast' + isIEEE _ = defaultIsIEEE (undefined :: Exp Float64) atan2 = mkAtan2 - isNaN = mkIsNaN . mkBitcast @Double - isInfinite = mkIsInfinite . mkBitcast @Double - isDenormalized = ieee754 "isDenormalized" (ieee754_f64_is_denormalized . mkBitcast) - isNegativeZero = ieee754 "isNegativeZero" (ieee754_f64_is_negative_zero . mkBitcast) - decodeFloat = ieee754 "decodeFloat" (\x -> let T2 m n = ieee754_f64_decode (mkBitcast x) - in (m, n)) - encodeFloat x e = mkBitcast (encodeFloat @Double x e) + + +instance KnownNat n => RealFloat (Vec n Float128) where + type Exponent (Vec n Float128) = Vec n Int + floatRadix _ = defaultFloatRadix (undefined :: Exp Float128) + floatDigits _ = defaultFloatDigits (undefined :: Exp Float128) + floatRange _ = defaultFloatRange (undefined :: Exp Float128) + decodeFloat = ieee754_f128_decode . mkBitcast' + encodeFloat = defaultEncodeFloat + exponent = defaultExponent + significand = defaultSignificand + scaleFloat = defaultScaleFloat + isNaN = mkIsNaN + isInfinite = mkIsInfinite + isDenormalized = coerce . ieee754_f128_is_denormalized . mkBitcast' + isNegativeZero = coerce . ieee754_f128_is_negative_zero . mkBitcast' + isIEEE _ = defaultIsIEEE (undefined :: Exp Float128) + atan2 = mkAtan2 + +mkIsNaN :: (IsFloating (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp b +mkIsNaN = mkPrimUnary $ PrimIsNaN floatingType + +mkIsInfinite :: (IsFloating (EltR t), BitOrMask (EltR t) ~ EltR b) => Exp t -> Exp b +mkIsInfinite = mkPrimUnary $ PrimIsInfinite floatingType + +mkAtan2 :: IsFloating (EltR t) => Exp t -> Exp t -> Exp t +mkAtan2 = mkPrimBinary $ PrimAtan2 floatingType -- To satisfy superclass constraints @@ -203,10 +315,127 @@ preludeError x ] -ieee754 :: forall a b. HasCallStack => P.RealFloat a => Builder -> (Exp a -> b) -> Exp a -> b -ieee754 name f x - | P.isIEEE (undefined::a) = f x - | otherwise = internalError (builder % ": Not implemented for non-IEEE floating point") name +-- GHC's type level natural normalisation isn't strong enough to deduce (n * 32) == (n * 32) +mkBitcast' + :: forall b a n. (IsScalar (VecR n a), IsScalar (VecR n b), BitSizeEq (EltR a) (EltR b)) + => Exp (Vec n a) + -> Exp (Vec n b) +mkBitcast' (Exp a) = mkExp $ Bitcast (scalarType @(VecR n a)) (scalarType @(VecR n b)) a + +defaultFloatRadix :: forall a. P.RealFloat a => Exp a -> Exp Int +defaultFloatRadix _ = P.fromInteger (P.floatRadix (undefined::a)) + +defaultFloatDigits :: forall a. P.RealFloat a => Exp a -> Exp Int +defaultFloatDigits _ = constant (P.floatDigits (undefined::a)) + +defaultFloatRange :: forall a. P.RealFloat a => Exp a -> Exp (Int, Int) +defaultFloatRange _ = constant (P.floatRange (undefined::a)) + +defaultIsIEEE :: forall a. P.RealFloat a => Exp a -> Exp Bool +defaultIsIEEE _ = constant (P.isIEEE (undefined::a)) + +defaultEncodeFloat + :: forall n a. (SIMD n a, RealFloat a, RealFloat (Vec n a), FromIntegral Int a, FromIntegral (Significand (Vec n a)) (Vec n a), FromIntegral (Exponent (Vec n a)) (Vec n a)) + => Exp (Significand (Vec n a)) + -> Exp (Exponent (Vec n a)) + -> Exp (Vec n a) +defaultEncodeFloat x e = + let d = splat (fromIntegral (floatRadix (undefined :: Exp a))) + in fromIntegral x * (d ** fromIntegral e) + +defaultExponent + :: (RealFloat (Vec n a), Significand (Vec n a) ~ Vec n s, Exponent (Vec n a) ~ Vec n Int, VEq n a, VEq n s) + => Exp (Vec n a) + -> Exp (Vec n Int) +defaultExponent x = + let T2 m n = decodeFloat x + d = splat (floatDigits x) + in + select (m ==* 0) 0 (n + d) + +defaultSignificand + :: (RealFloat (Vec n a), Exponent (Vec n a) ~ Vec n Int, KnownNat n) + => Exp (Vec n a) + -> Exp (Vec n a) +defaultSignificand x = + let T2 m _ = decodeFloat x + d = splat (floatDigits x) + in encodeFloat m (negate d) + +defaultScaleFloat + :: (RealFloat (Vec n a), Exponent (Vec n a) ~ Vec n Int, BitOrMask (EltR (Vec n a)) ~ PrimMask n, VEq n a) + => Exp Int + -> Exp (Vec n a) + -> Exp (Vec n a) +defaultScaleFloat k x = + select (k' ==* 0 ||* isFix) x (encodeFloat m (n + clamp b)) + where + k' = splat k + isFix = x ==* 0 ||* isNaN x ||* isInfinite x + T2 m n = decodeFloat x + T2 l h = floatRange x + d = floatDigits x + b = splat (h - l + 4*d) + -- n+k may overflow, which would lead to incorrect results, hence we clamp + -- the scaling parameter. If (n+k) would be larger than h, (n + clamp b k) + -- must be too, similar for smaller than (l-d). + clamp bd = max (-bd) (min bd k') + +-- Must test for ±0.0 to avoid returning -0.0 in the second component of the +-- pair. Unfortunately the branching costs a lot of performance. +-- +-- Orphaned from RealFrac module +-- +-- defaultProperFraction +-- :: (ToFloating b a, RealFrac a, IsIntegral b, Num b, Floating a) +-- => Exp a +-- -> (Exp b, Exp a) +-- defaultProperFraction x = +-- unlift $ Exp +-- $ Cond (x == 0) (tup2 (0, 0)) +-- (tup2 (n, f)) +-- where +-- n = truncate x +-- f = x - toFloating n + +defaultProperFraction + :: (RealFloat a, FromIntegral (Significand a) b, Integral b) + => Exp a + -> Exp (b, a) +defaultProperFraction x = + cond (n >= 0) + (T2 (fromIntegral m * (2 ^ n)) 0.0) + (T2 (fromIntegral q) (encodeFloat r n)) + where + T2 m n = decodeFloat x + (q, r) = quotRem m (2 ^ (negate n)) + +-- defaultProperFraction' +-- :: (SIMD n a, SIMD n b, RealFloat (Vec n a), Exponent (Vec n a) ~ Vec n Int, FromIntegral (Significand (Vec n a)) (Vec n b), Integral (Vec n b)) +-- => Exp (Vec n a) +-- -> Exp (Vec n b, Vec n a) +-- defaultProperFraction' x = +-- T2 (select p (fromIntegral m * (2 ^ n)) (fromIntegral q)) +-- (select p 0.0 (encodeFloat r n)) +-- where +-- T2 m n = decodeFloat x +-- (q, r) = quotRem m (2 ^ (negate n)) +-- p = n >=* 0 + +-- This is a bit weird because we really want to apply the function late-wise, +-- but there isn't really a way we can do that. Boo. ---TLM 2022-09-20 +-- +defaultProperFraction' + :: (SIMD n a, RealFloat (Vec n a), Exponent (Vec n a) ~ Vec n Int, FromIntegral (Significand (Vec n a)) b, Integral b) + => Exp (Vec n a) + -> Exp (b, Vec n a) +defaultProperFraction' x = + T2 (cond (n >= 0) (fromIntegral m * (2 ^ n)) (fromIntegral q)) + (select (n >=* 0) 0.0 (encodeFloat r n)) + where + T2 m n = decodeFloat x + (q, r) = quotRem m (2 ^ (negate n)) + -- From: ghc/libraries/base/cbits/primFloat.c -- ------------------------------------------ @@ -216,102 +445,128 @@ ieee754 name f x -- * mantissa is non-zero. -- * (don't care about setting of sign bit.) -- -ieee754_f64_is_denormalized :: Exp Word64 -> Exp Bool +ieee754_f128_is_denormalized :: KnownNat n => Exp (Vec n Word128) -> Exp (Vec n Bool) +ieee754_f128_is_denormalized x = + ieee754_f128_mantissa x ==* 0 &&* + ieee754_f128_exponent x /=* 0 + +ieee754_f64_is_denormalized :: KnownNat n => Exp (Vec n Word64) -> Exp (Vec n Bool) ieee754_f64_is_denormalized x = - ieee754_f64_mantissa x == 0 && - ieee754_f64_exponent x /= 0 + ieee754_f64_mantissa x ==* 0 &&* + ieee754_f64_exponent x /=* 0 -ieee754_f32_is_denormalized :: Exp Word32 -> Exp Bool +ieee754_f32_is_denormalized :: KnownNat n => Exp (Vec n Word32) -> Exp (Vec n Bool) ieee754_f32_is_denormalized x = - ieee754_f32_mantissa x == 0 && - ieee754_f32_exponent x /= 0 + ieee754_f32_mantissa x ==* 0 &&* + ieee754_f32_exponent x /=* 0 -ieee754_f16_is_denormalized :: Exp Word16 -> Exp Bool +ieee754_f16_is_denormalized :: KnownNat n => Exp (Vec n Word16) -> Exp (Vec n Bool) ieee754_f16_is_denormalized x = - ieee754_f16_mantissa x == 0 && - ieee754_f16_exponent x /= 0 + ieee754_f16_mantissa x ==* 0 &&* + ieee754_f16_exponent x /=* 0 -- Negative zero if only the sign bit is set -- -ieee754_f64_is_negative_zero :: Exp Word64 -> Exp Bool +ieee754_f128_is_negative_zero :: KnownNat n => Exp (Vec n Word128) -> Exp (Vec n Bool) +ieee754_f128_is_negative_zero x = + ieee754_f128_negative x &&* + ieee754_f128_exponent x ==* 0 &&* + ieee754_f128_mantissa x ==* 0 + +ieee754_f64_is_negative_zero :: KnownNat n => Exp (Vec n Word64) -> Exp (Vec n Bool) ieee754_f64_is_negative_zero x = - ieee754_f64_negative x && - ieee754_f64_exponent x == 0 && - ieee754_f64_mantissa x == 0 + ieee754_f64_negative x &&* + ieee754_f64_exponent x ==* 0 &&* + ieee754_f64_mantissa x ==* 0 -ieee754_f32_is_negative_zero :: Exp Word32 -> Exp Bool +ieee754_f32_is_negative_zero :: KnownNat n => Exp (Vec n Word32) -> Exp (Vec n Bool) ieee754_f32_is_negative_zero x = - ieee754_f32_negative x && - ieee754_f32_exponent x == 0 && - ieee754_f32_mantissa x == 0 + ieee754_f32_negative x &&* + ieee754_f32_exponent x ==* 0 &&* + ieee754_f32_mantissa x ==* 0 -ieee754_f16_is_negative_zero :: Exp Word16 -> Exp Bool +ieee754_f16_is_negative_zero :: KnownNat n => Exp (Vec n Word16) -> Exp (Vec n Bool) ieee754_f16_is_negative_zero x = - ieee754_f16_negative x && - ieee754_f16_exponent x == 0 && - ieee754_f16_mantissa x == 0 + ieee754_f16_negative x &&* + ieee754_f16_exponent x ==* 0 &&* + ieee754_f16_mantissa x ==* 0 -- Assume the host processor stores integers and floating point numbers in the -- same endianness (true for modern processors). -- --- To recap, here's the representation of a double precision +-- To recap, here's the representation of a quadruple precision -- IEEE floating point number: -- +-- sign 127 sign bit (0==positive, 1==negative) +-- exponent 126-112 exponent (biased by 16383) +-- fraction 111-0 fraction (bits to right of binary part) +-- +ieee754_f128_mantissa :: KnownNat n => Exp (Vec n Word128) -> Exp (Vec n Word128) +ieee754_f128_mantissa x = x .&. 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFF + +ieee754_f128_exponent :: KnownNat n => Exp (Vec n Word128) -> Exp (Vec n Word16) +ieee754_f128_exponent x = fromIntegral (x `unsafeShiftR` 112) .&. 0x7FFF + +ieee754_f128_negative :: KnownNat n => Exp (Vec n Word128) -> Exp (Vec n Bool) +ieee754_f128_negative x = testBit x 127 + +-- Representation of a double precision IEEE floating point number: +-- -- sign 63 sign bit (0==positive, 1==negative) -- exponent 62-52 exponent (biased by 1023) -- fraction 51-0 fraction (bits to right of binary point) -- -ieee754_f64_mantissa :: Exp Word64 -> Exp Word64 +ieee754_f64_mantissa :: KnownNat n => Exp (Vec n Word64) -> Exp (Vec n Word64) ieee754_f64_mantissa x = x .&. 0xFFFFFFFFFFFFF -ieee754_f64_exponent :: Exp Word64 -> Exp Word16 +ieee754_f64_exponent :: KnownNat n => Exp (Vec n Word64) -> Exp (Vec n Word16) ieee754_f64_exponent x = fromIntegral (x `unsafeShiftR` 52) .&. 0x7FF -ieee754_f64_negative :: Exp Word64 -> Exp Bool +ieee754_f64_negative :: KnownNat n => Exp (Vec n Word64) -> Exp (Vec n Bool) ieee754_f64_negative x = testBit x 63 --- Representation of single precision IEEE floating point number: +-- Representation of a single precision IEEE floating point number: -- -- sign 31 sign bit (0==positive, 1==negative) -- exponent 30-23 exponent (biased by 127) -- fraction 22-0 fraction (bits to right of binary point) -- -ieee754_f32_mantissa :: Exp Word32 -> Exp Word32 +ieee754_f32_mantissa :: KnownNat n => Exp (Vec n Word32) -> Exp (Vec n Word32) ieee754_f32_mantissa x = x .&. 0x7FFFFF -ieee754_f32_exponent :: Exp Word32 -> Exp Word8 +ieee754_f32_exponent :: KnownNat n => Exp (Vec n Word32) -> Exp (Vec n Word8) ieee754_f32_exponent x = fromIntegral (x `unsafeShiftR` 23) -ieee754_f32_negative :: Exp Word32 -> Exp Bool +ieee754_f32_negative :: KnownNat n => Exp (Vec n Word32) -> Exp (Vec n Bool) ieee754_f32_negative x = testBit x 31 --- Representation of half precision IEEE floating point number: +-- Representation of a half precision IEEE floating point number: -- -- sign 15 sign bit (0==positive, 1==negative) -- exponent 14-10 exponent (biased by 15) -- fraction 9-0 fraction (bits to right of binary point) -- -ieee754_f16_mantissa :: Exp Word16 -> Exp Word16 +ieee754_f16_mantissa :: KnownNat n => Exp (Vec n Word16) -> Exp (Vec n Word16) ieee754_f16_mantissa x = x .&. 0x3FF -ieee754_f16_exponent :: Exp Word16 -> Exp Word8 +ieee754_f16_exponent :: KnownNat n => Exp (Vec n Word16) -> Exp (Vec n Word8) ieee754_f16_exponent x = fromIntegral (x `unsafeShiftR` 10) .&. 0x1F -ieee754_f16_negative :: Exp Word16 -> Exp Bool +ieee754_f16_negative :: KnownNat n => Exp (Vec n Word16) -> Exp (Vec n Bool) ieee754_f16_negative x = testBit x 15 -- reverse engineered following the below -ieee754_f16_decode :: Exp Word16 -> Exp (Int16, Int) +ieee754_f16_decode :: forall n. KnownNat n => Exp (Vec n Word16) -> Exp (Vec n Int16, Vec n Int) ieee754_f16_decode i = let _HHIGHBIT = 0x0400 _HMSBIT = 0x8000 - _HMINEXP = ((_HALF_MIN_EXP) - (_HALF_MANT_DIG) - 1) - _HALF_MANT_DIG = floatDigits (undefined::Exp Half) - (_HALF_MIN_EXP, _HALF_MAX_EXP) = floatRange (undefined::Exp Half) + _HMINEXP = splat ((_HALF_MIN_EXP) - (_HALF_MANT_DIG) - 1) + _HALF_MANT_DIG = floatDigits (undefined::Exp Float16) + T2 _HALF_MIN_EXP _HALF_MAX_EXP = floatRange (undefined::Exp Float16) high1 = fromIntegral i high2 = high1 .&. (_HHIGHBIT - 1) @@ -324,28 +579,38 @@ ieee754_f16_decode i = -- don't add hidden bit to denorms (T2 (high2 .|. _HHIGHBIT) exp1) -- a denorm, normalise the mantissa - (while (\(T2 h _) -> (h .&. _HHIGHBIT) /= 0 ) - (\(T2 h e) -> T2 (h `unsafeShiftL` 1) (e-1)) + (while (\(T2 h _) -> (h .&. _HHIGHBIT) /= 0) + (\(T2 h e) -> let p = (h .&. _HHIGHBIT) /=* 0 + in T2 (select p (h `unsafeShiftL` 1) h) + (select p (e-1) e)) (T2 high2 exp2)) - high4 = cond (fromIntegral i < (0 :: Exp Int16)) (-high3) high3 + high4 = select (fromIntegral i <* (0 :: Exp (Vec n Int16))) (-high3) high3 + z = high1 .&. complement _HMSBIT ==* 0 in - cond (high1 .&. complement _HMSBIT == 0) - (T2 0 0) - (T2 high4 exp3) + T2 (select z 0 high4) + (select z 0 exp3) -- From: ghc/rts/StgPrimFloat.c -- ---------------------------- +-- +-- The fast-path (no denormalised values) looks good to me, but if any one of +-- the lanes contains a denormalised value then all lanes need to continue +-- looping in predicated style to normalise the mantissa. We do a bit of +-- redundant work here that could be avoided; maybe the codegen will clean that +-- up for us, but if not it shouldn't matter too much anyway (slow path...). +-- -- TLM 2022-09-20. +-- -ieee754_f32_decode :: Exp Word32 -> Exp (Int32, Int) +ieee754_f32_decode :: forall n. KnownNat n => Exp (Vec n Word32) -> Exp (Vec n Int32, Vec n Int) ieee754_f32_decode i = let _FHIGHBIT = 0x00800000 _FMSBIT = 0x80000000 - _FMINEXP = ((_FLT_MIN_EXP) - (_FLT_MANT_DIG) - 1) - _FLT_MANT_DIG = floatDigits (undefined::Exp Float) - (_FLT_MIN_EXP, _FLT_MAX_EXP) = floatRange (undefined::Exp Float) + _FMINEXP = splat ((_FLT_MIN_EXP) - (_FLT_MANT_DIG) - 1) + _FLT_MANT_DIG = floatDigits (undefined::Exp Float32) + T2 _FLT_MIN_EXP _FLT_MAX_EXP = floatRange (undefined::Exp Float32) high1 = fromIntegral i high2 = high1 .&. (_FHIGHBIT - 1) @@ -358,36 +623,38 @@ ieee754_f32_decode i = -- don't add hidden bit to denorms (T2 (high2 .|. _FHIGHBIT) exp1) -- a denorm, normalise the mantissa - (while (\(T2 h _) -> (h .&. _FHIGHBIT) /= 0 ) - (\(T2 h e) -> T2 (h `unsafeShiftL` 1) (e-1)) + (while (\(T2 h _) -> (h .&. _FHIGHBIT) /= 0) + (\(T2 h e) -> let p = (h .&. _FHIGHBIT) /=* 0 + in T2 (select p (h `unsafeShiftL` 1) h) + (select p (e-1) e)) (T2 high2 exp2)) - high4 = cond (fromIntegral i < (0 :: Exp Int32)) (-high3) high3 + high4 = select (fromIntegral i <* (0 :: Exp (Vec n Int32))) (-high3) high3 + z = high1 .&. complement _FMSBIT ==* 0 in - cond (high1 .&. complement _FMSBIT == 0) - (T2 0 0) - (T2 high4 exp3) + T2 (select z 0 high4) + (select z 0 exp3) -ieee754_f64_decode :: Exp Word64 -> Exp (Int64, Int) +ieee754_f64_decode :: KnownNat n => Exp (Vec n Word64) -> Exp (Vec n Int64, Vec n Int) ieee754_f64_decode i = let T4 s h l e = ieee754_f64_decode2 i in T2 (fromIntegral s * (fromIntegral h `unsafeShiftL` 32 .|. fromIntegral l)) e -ieee754_f64_decode2 :: Exp Word64 -> Exp (Int, Word32, Word32, Int) +ieee754_f64_decode2 :: forall n. KnownNat n => Exp (Vec n Word64) -> Exp (Vec n Int64, Vec n Word32, Vec n Word32, Vec n Int) ieee754_f64_decode2 i = let _DHIGHBIT = 0x00100000 _DMSBIT = 0x80000000 - _DMINEXP = ((_DBL_MIN_EXP) - (_DBL_MANT_DIG) - 1) - _DBL_MANT_DIG = floatDigits (undefined::Exp Double) - (_DBL_MIN_EXP, _DBL_MAX_EXP) = floatRange (undefined::Exp Double) + _DMINEXP = splat ((_DBL_MIN_EXP) - (_DBL_MANT_DIG) - 1) + _DBL_MANT_DIG = floatDigits (undefined::Exp Float64) + T2 _DBL_MIN_EXP _DBL_MAX_EXP = floatRange (undefined::Exp Float64) low = fromIntegral i high = fromIntegral (i `unsafeShiftR` 32) iexp = (fromIntegral ((high `unsafeShiftR` 20) .&. 0x7FF) + _DMINEXP) - sign = cond (fromIntegral i < (0 :: Exp Int64)) (-1) 1 + sign = select (fromIntegral i <* (0 :: Exp (Vec n Int64))) (-1) 1 high2 = high .&. (_DHIGHBIT - 1) iexp2 = iexp + 1 @@ -399,13 +666,21 @@ ieee754_f64_decode2 i = -- a denorm, nermalise the mantissa (while (\(T3 h _ _) -> (h .&. _DHIGHBIT) /= 0) (\(T3 h l e) -> - let h1 = h `unsafeShiftL` 1 + let p = (h .&. _DHIGHBIT) /=* 0 + h1 = h `unsafeShiftL` 1 h2 = cond ((l .&. _DMSBIT) /= 0) (h1+1) h1 - in T3 h2 (l `unsafeShiftL` 1) (e-1)) + in T3 (select p h2 h) + (select p (l `unsafeShiftL` 1) l) + (select p (e-1) e)) (T3 high2 low iexp2)) + z = low ==* 0 &&* (high .&. (complement _DMSBIT)) ==* 0 in - cond (low == 0 && (high .&. (complement _DMSBIT)) == 0) - (T4 1 0 0 0) - (T4 sign hi lo ie) + T4 (select z 1 sign) + (select z 0 hi) + (select z 0 lo) + (select z 0 ie) + +ieee754_f128_decode :: KnownNat n => Exp (Vec n Word128) -> Exp (Vec n Int128, Vec n Int) +ieee754_f128_decode = error "TODO: ieee754_f128_decode" diff --git a/src/Data/Array/Accelerate/Classes/RealFloat.hs-boot b/src/Data/Array/Accelerate/Classes/RealFloat.hs-boot deleted file mode 100644 index a4f2878bd..000000000 --- a/src/Data/Array/Accelerate/Classes/RealFloat.hs-boot +++ /dev/null @@ -1,67 +0,0 @@ -{-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE FlexibleContexts #-} --- | --- Module : Data.Array.Accelerate.Classes.RealFloat --- Copyright : [2019..2020] The Accelerate Team --- License : BSD3 --- --- Maintainer : Trevor L. McDonell --- Stability : experimental --- Portability : non-portable (GHC extensions) --- - -module Data.Array.Accelerate.Classes.RealFloat - where - -import Data.Array.Accelerate.Smart -import Data.Array.Accelerate.Type - -import Data.Array.Accelerate.Classes.Floating -import Data.Array.Accelerate.Classes.FromIntegral -import {-# SOURCE #-} Data.Array.Accelerate.Classes.RealFrac - -import Prelude ( Bool ) -import qualified Prelude as P - - -class (RealFrac a, Floating a) => RealFloat a where - floatRadix :: Exp a -> Exp Int64 -- Integer - floatDigits :: Exp a -> Exp Int - floatRange :: Exp a -> (Exp Int, Exp Int) - decodeFloat :: Exp a -> (Exp Int64, Exp Int) -- Integer - encodeFloat :: Exp Int64 -> Exp Int -> Exp a -- Integer - exponent :: Exp a -> Exp Int - significand :: Exp a -> Exp a - scaleFloat :: Exp Int -> Exp a -> Exp a - isNaN :: Exp a -> Exp Bool - isInfinite :: Exp a -> Exp Bool - isDenormalized :: Exp a -> Exp Bool - isNegativeZero :: Exp a -> Exp Bool - isIEEE :: Exp a -> Exp Bool - atan2 :: Exp a -> Exp a -> Exp a - - exponent = P.undefined - significand = P.undefined - scaleFloat = P.undefined - - default floatRadix :: P.RealFloat a => Exp a -> Exp Int64 - floatRadix _ = P.undefined - - default floatDigits :: P.RealFloat a => Exp a -> Exp Int - floatDigits _ = P.undefined - - default floatRange :: P.RealFloat a => Exp a -> (Exp Int, Exp Int) - floatRange _ = P.undefined - - default encodeFloat :: (FromIntegral Int a, FromIntegral Int64 a) => Exp Int64 -> Exp Int -> Exp a - encodeFloat _ _ = P.undefined - - default isIEEE :: P.RealFloat a => Exp a -> Exp Bool - isIEEE _ = P.undefined - -instance RealFloat Half -instance RealFloat Float -instance RealFloat Double -instance RealFloat CFloat -instance RealFloat CDouble - diff --git a/src/Data/Array/Accelerate/Classes/RealFrac.hs b/src/Data/Array/Accelerate/Classes/RealFrac.hs index 4cec0fde9..6c7a285a5 100644 --- a/src/Data/Array/Accelerate/Classes/RealFrac.hs +++ b/src/Data/Array/Accelerate/Classes/RealFrac.hs @@ -3,6 +3,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -fno-warn-orphans #-} @@ -23,47 +24,50 @@ module Data.Array.Accelerate.Classes.RealFrac ( ) where -import Data.Array.Accelerate.Language ( (^), cond, even ) -import Data.Array.Accelerate.Lift ( unlift ) -import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.AST ( PrimFun(..) ) +import Data.Array.Accelerate.Language ( cond, even ) +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Eq -import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Classes.Floating import Data.Array.Accelerate.Classes.Fractional import Data.Array.Accelerate.Classes.FromIntegral import Data.Array.Accelerate.Classes.Integral import Data.Array.Accelerate.Classes.Num +import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Classes.ToFloating -import {-# SOURCE #-} Data.Array.Accelerate.Classes.RealFloat -- defaultProperFraction import Data.Maybe +import Data.Kind +import Prelude ( ($), String, error, otherwise, unlines ) import Text.Printf -import Prelude ( ($), String, error, unlines, otherwise ) import qualified Prelude as P -- | Generalisation of 'P.div' to any instance of 'RealFrac' -- -div' :: (RealFrac a, FromIntegral Int64 b, Integral b) => Exp a -> Exp a -> Exp b +div' :: (RealFrac a, FromIntegral (Significand a) b, Integral b) => Exp a -> Exp a -> Exp b div' n d = floor (n / d) -- | Generalisation of 'P.mod' to any instance of 'RealFrac' -- -mod' :: (Floating a, RealFrac a, ToFloating Int64 a) => Exp a -> Exp a -> Exp a +mod' :: forall a. (Floating a, RealFrac a, Integral (Significand a), ToFloating (Significand a) a, FromIntegral (Significand a) (Significand a)) + => Exp a + -> Exp a + -> Exp a mod' n d = n - (toFloating f) * d where - f :: Exp Int64 + f :: Exp (Significand a) f = div' n d -- | Generalisation of 'P.divMod' to any instance of 'RealFrac' -- divMod' - :: (Floating a, RealFrac a, Integral b, FromIntegral Int64 b, ToFloating b a) + :: (Floating a, RealFrac a, Integral b, FromIntegral (Significand a) b, ToFloating b a) => Exp a -> Exp a -> (Exp b, Exp a) @@ -74,7 +78,15 @@ divMod' n d = (f, n - (toFloating f) * d) -- | Extracting components of fractions. -- -class (Ord a, Fractional a) => RealFrac a where +class (Ord a, Fractional a, Integral (Significand a)) => RealFrac a where + -- | The significand (also known as the mantissa) is the part of a number in + -- floating point representation consisting of the significant digits. + -- Generally speaking, this is the integral part of a fractional number. + -- + type Significand a :: Type + + {-# MINIMAL properFraction #-} + -- | The function 'properFraction' takes a real fractional number @x@ and -- returns a pair @(n,f)@ such that @x = n+f@, and: -- @@ -85,7 +97,7 @@ class (Ord a, Fractional a) => RealFrac a where -- -- The default definitions of the 'ceiling', 'floor', 'truncate' -- and 'round' functions are in terms of 'properFraction'. - properFraction :: (Integral b, FromIntegral Int64 b) => Exp a -> (Exp b, Exp a) + properFraction :: (FromIntegral (Significand a) b, Integral b) => Exp a -> Exp (b, a) -- The function 'splitFraction' takes a real fractional number @x@ and -- returns a pair @(n,f)@ such that @x = n+f@, and: @@ -104,109 +116,58 @@ class (Ord a, Fractional a) => RealFrac a where -- splitFraction / fraction are from numeric-prelude Algebra.RealRing -- | @truncate x@ returns the integer nearest @x@ between zero and @x@ - truncate :: (Integral b, FromIntegral Int64 b) => Exp a -> Exp b + truncate :: (Integral b, FromIntegral (Significand a) b) => Exp a -> Exp b truncate = defaultTruncate -- | @'round' x@ returns the nearest integer to @x@; the even integer if @x@ -- is equidistant between two integers - round :: (Integral b, FromIntegral Int64 b) => Exp a -> Exp b - round = defaultRound + round :: (Integral b, FromIntegral (Significand a) b) => Exp a -> Exp b + round = defaultRound -- | @'ceiling' x@ returns the least integer not less than @x@ - ceiling :: (Integral b, FromIntegral Int64 b) => Exp a -> Exp b - ceiling = defaultCeiling + ceiling :: (Integral b, FromIntegral (Significand a) b) => Exp a -> Exp b + ceiling = defaultCeiling -- | @'floor' x@ returns the greatest integer not greater than @x@ - floor :: (Integral b, FromIntegral Int64 b) => Exp a -> Exp b - floor = defaultFloor - -instance RealFrac Half where - properFraction = defaultProperFraction - -instance RealFrac Float where - properFraction = defaultProperFraction - -instance RealFrac Double where - properFraction = defaultProperFraction - -instance RealFrac CFloat where - properFraction = defaultProperFraction - truncate = defaultTruncate - round = defaultRound - ceiling = defaultCeiling - floor = defaultFloor - -instance RealFrac CDouble where - properFraction = defaultProperFraction - truncate = defaultTruncate - round = defaultRound - ceiling = defaultCeiling - floor = defaultFloor - + floor :: (Integral b, FromIntegral (Significand a) b) => Exp a -> Exp b + floor = defaultFloor --- Must test for ±0.0 to avoid returning -0.0 in the second component of the --- pair. Unfortunately the branching costs a lot of performance. --- --- defaultProperFraction --- :: (ToFloating b a, RealFrac a, IsIntegral b, Num b, Floating a) --- => Exp a --- -> (Exp b, Exp a) --- defaultProperFraction x = --- unlift $ Exp --- $ Cond (x == 0) (tup2 (0, 0)) --- (tup2 (n, f)) --- where --- n = truncate x --- f = x - toFloating n - -defaultProperFraction - :: (RealFloat a, FromIntegral Int64 b, Integral b) - => Exp a - -> (Exp b, Exp a) -defaultProperFraction x - = unlift - $ cond (n >= 0) - (T2 (fromIntegral m * (2 ^ n)) 0.0) - (T2 (fromIntegral q) (encodeFloat r n)) - where - (m, n) = decodeFloat x - (q, r) = quotRem m (2 ^ (negate n)) -defaultTruncate :: forall a b. (RealFrac a, Integral b, FromIntegral Int64 b) => Exp a -> Exp b +defaultTruncate :: forall a b. (RealFrac a, Integral b, FromIntegral (Significand a) b) => Exp a -> Exp b defaultTruncate x - | Just IsFloatingDict <- isFloating @a - , Just IsIntegralDict <- isIntegral @b - = mkTruncate x + | Just FloatingDict <- floatingDict @a + , Just IntegralDict <- integralDict @b + = mkPrimUnary (PrimTruncate floatingType integralType) x -- | otherwise - = let (n, _) = properFraction x in n + = let T2 n _ = properFraction x in n -defaultCeiling :: forall a b. (RealFrac a, Integral b, FromIntegral Int64 b) => Exp a -> Exp b +defaultCeiling :: forall a b. (RealFrac a, Integral b, FromIntegral (Significand a) b) => Exp a -> Exp b defaultCeiling x - | Just IsFloatingDict <- isFloating @a - , Just IsIntegralDict <- isIntegral @b - = mkCeiling x + | Just FloatingDict <- floatingDict @a + , Just IntegralDict <- integralDict @b + = mkPrimUnary (PrimCeiling floatingType integralType) x -- | otherwise - = let (n, r) = properFraction x in cond (r > 0) (n+1) n + = let T2 n r = properFraction x in cond (r > 0) (n+1) n -defaultFloor :: forall a b. (RealFrac a, Integral b, FromIntegral Int64 b) => Exp a -> Exp b +defaultFloor :: forall a b. (RealFrac a, Integral b, FromIntegral (Significand a) b) => Exp a -> Exp b defaultFloor x - | Just IsFloatingDict <- isFloating @a - , Just IsIntegralDict <- isIntegral @b - = mkFloor x + | Just FloatingDict <- floatingDict @a + , Just IntegralDict <- integralDict @b + = mkPrimUnary (PrimFloor floatingType integralType) x -- | otherwise - = let (n, r) = properFraction x in cond (r < 0) (n-1) n + = let T2 n r = properFraction x in cond (r < 0) (n-1) n -defaultRound :: forall a b. (RealFrac a, Integral b, FromIntegral Int64 b) => Exp a -> Exp b +defaultRound :: forall a b. (RealFrac a, Integral b, FromIntegral (Significand a) b) => Exp a -> Exp b defaultRound x - | Just IsFloatingDict <- isFloating @a - , Just IsIntegralDict <- isIntegral @b - = mkRound x + | Just FloatingDict <- floatingDict @a + , Just IntegralDict <- integralDict @b + = mkPrimUnary (PrimRound floatingType integralType) x -- | otherwise - = let (n, r) = properFraction x + = let T2 n r = properFraction x m = cond (r < 0.0) (n-1) (n+1) half_down = abs r - 0.5 p = compare half_down 0.0 @@ -216,46 +177,81 @@ defaultRound x {- otherwise -} m -data IsFloatingDict a where - IsFloatingDict :: IsFloating a => IsFloatingDict a +data FloatingDict a where + FloatingDict :: IsFloating a => FloatingDict a -data IsIntegralDict a where - IsIntegralDict :: IsIntegral a => IsIntegralDict a +data IntegralDict a where + IntegralDict :: IsIntegral a => IntegralDict a -isFloating :: forall a. Elt a => Maybe (IsFloatingDict (EltR a)) -isFloating - | TupRsingle t <- eltR @a - , SingleScalarType s <- t - , NumSingleType n <- s - , FloatingNumType f <- n - = case f of - TypeHalf{} -> Just IsFloatingDict - TypeFloat{} -> Just IsFloatingDict - TypeDouble{} -> Just IsFloatingDict - -- - | otherwise - = Nothing - -isIntegral :: forall a. Elt a => Maybe (IsIntegralDict (EltR a)) -isIntegral - | TupRsingle t <- eltR @a - , SingleScalarType s <- t - , NumSingleType n <- s - , IntegralNumType i <- n - = case i of - TypeInt{} -> Just IsIntegralDict - TypeInt8{} -> Just IsIntegralDict - TypeInt16{} -> Just IsIntegralDict - TypeInt32{} -> Just IsIntegralDict - TypeInt64{} -> Just IsIntegralDict - TypeWord{} -> Just IsIntegralDict - TypeWord8{} -> Just IsIntegralDict - TypeWord16{} -> Just IsIntegralDict - TypeWord32{} -> Just IsIntegralDict - TypeWord64{} -> Just IsIntegralDict - -- - | otherwise - = Nothing +floatingDict :: forall a. Elt a => Maybe (FloatingDict (EltR a)) +floatingDict = go (eltR @a) + where + go :: TypeR t -> Maybe (FloatingDict t) + go (TupRsingle t) = scalar t + go _ = Nothing + + scalar :: ScalarType t -> Maybe (FloatingDict t) + scalar (NumScalarType t) = num t + scalar _ = Nothing + + num :: NumType t -> Maybe (FloatingDict t) + num (FloatingNumType t) = floating t + num _ = Nothing + + floating :: forall t. FloatingType t -> Maybe (FloatingDict t) + floating (SingleFloatingType t) = + case t of + TypeFloat16 -> Just FloatingDict + TypeFloat32 -> Just FloatingDict + TypeFloat64 -> Just FloatingDict + TypeFloat128 -> Just FloatingDict + floating (VectorFloatingType _ t) = + case t of + TypeFloat16 -> Just FloatingDict + TypeFloat32 -> Just FloatingDict + TypeFloat64 -> Just FloatingDict + TypeFloat128 -> Just FloatingDict + +integralDict :: forall a. Elt a => Maybe (IntegralDict (EltR a)) +integralDict = go (eltR @a) + where + go :: TypeR t -> Maybe (IntegralDict t) + go (TupRsingle t) = scalar t + go _ = Nothing + + scalar :: ScalarType t -> Maybe (IntegralDict t) + scalar (NumScalarType t) = num t + scalar _ = Nothing + + num :: NumType t -> Maybe (IntegralDict t) + num (IntegralNumType t) = integral t + num _ = Nothing + + integral :: forall t. IntegralType t -> Maybe (IntegralDict t) + integral (SingleIntegralType t) = + case t of + TypeInt8 -> Just IntegralDict + TypeInt16 -> Just IntegralDict + TypeInt32 -> Just IntegralDict + TypeInt64 -> Just IntegralDict + TypeInt128 -> Just IntegralDict + TypeWord8 -> Just IntegralDict + TypeWord16 -> Just IntegralDict + TypeWord32 -> Just IntegralDict + TypeWord64 -> Just IntegralDict + TypeWord128 -> Just IntegralDict + integral (VectorIntegralType _ t) = + case t of + TypeInt8 -> Just IntegralDict + TypeInt16 -> Just IntegralDict + TypeInt32 -> Just IntegralDict + TypeInt64 -> Just IntegralDict + TypeInt128 -> Just IntegralDict + TypeWord8 -> Just IntegralDict + TypeWord16 -> Just IntegralDict + TypeWord32 -> Just IntegralDict + TypeWord64 -> Just IntegralDict + TypeWord128 -> Just IntegralDict -- To satisfy superclass constraints @@ -276,3 +272,6 @@ preludeError x , "constraints for subsequent classes in the standard Haskell numeric hierarchy." ] +-- Instances declared in Data.Array.Accelerate.Classes.RealFloat to avoid +-- recursive modules + diff --git a/src/Data/Array/Accelerate/Classes/RealFrac.hs-boot b/src/Data/Array/Accelerate/Classes/RealFrac.hs-boot deleted file mode 100644 index 1bb85ec04..000000000 --- a/src/Data/Array/Accelerate/Classes/RealFrac.hs-boot +++ /dev/null @@ -1,43 +0,0 @@ -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE NoImplicitPrelude #-} --- | --- Module : Data.Array.Accelerate.Classes.RealFrac --- Copyright : [2019..2020] The Accelerate Team --- License : BSD3 --- --- Maintainer : Trevor L. McDonell --- Stability : experimental --- Portability : non-portable (GHC extensions) --- - -module Data.Array.Accelerate.Classes.RealFrac - where - -import Data.Array.Accelerate.Smart -import Data.Array.Accelerate.Type - -import Data.Array.Accelerate.Classes.Fractional -import Data.Array.Accelerate.Classes.FromIntegral -import Data.Array.Accelerate.Classes.Integral -import Data.Array.Accelerate.Classes.Ord - -import qualified Prelude as P - -class (Ord a, Fractional a) => RealFrac a where - properFraction :: (Integral b, FromIntegral Int64 b) => Exp a -> (Exp b, Exp a) - truncate :: (Integral b, FromIntegral Int64 b) => Exp a -> Exp b - round :: (Integral b, FromIntegral Int64 b) => Exp a -> Exp b - ceiling :: (Integral b, FromIntegral Int64 b) => Exp a -> Exp b - floor :: (Integral b, FromIntegral Int64 b) => Exp a -> Exp b - - truncate = P.undefined - round = P.undefined - ceiling = P.undefined - floor = P.undefined - -instance RealFrac Half -instance RealFrac Float -instance RealFrac Double -instance RealFrac CFloat -instance RealFrac CDouble - diff --git a/src/Data/Array/Accelerate/Classes/ToFloating.hs b/src/Data/Array/Accelerate/Classes/ToFloating.hs index c3f4545c6..05adf84e6 100644 --- a/src/Data/Array/Accelerate/Classes/ToFloating.hs +++ b/src/Data/Array/Accelerate/Classes/ToFloating.hs @@ -19,7 +19,10 @@ module Data.Array.Accelerate.Classes.ToFloating ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Floating @@ -39,42 +42,49 @@ class ToFloating a b where -- | General coercion to floating types toFloating :: (Num a, Floating b) => Exp a -> Exp b --- instance (Elt a, Elt b, IsNum a, IsFloating b) => ToFloating a b where --- toFloating = mkToFloating +mkToFloating :: (IsNum (EltR a), IsFloating (EltR b)) => Exp a -> Exp b +mkToFloating = mkPrimUnary $ PrimToFloating numType floatingType + -- Generate standard instances explicitly. See also: 'FromIntegral'. -- -$(runQ $ do - let - -- Get all the types that our dictionaries reify - digItOut :: Name -> Q [Name] - digItOut name = do - TyConI (DataD _ _ _ _ cons _) <- reify name - let - -- This is what a constructor such as IntegralNumType will be reified - -- as prior to GHC 8.4... - dig (NormalC _ [(_, AppT (ConT n) (VarT _))]) = digItOut n - -- ...but this is what IntegralNumType will be reified as on GHC 8.4 - -- and later, after the changes described in - -- https://ghc.haskell.org/trac/ghc/wiki/Migration/8.4#TemplateHaskellreificationchangesforGADTs - dig (ForallC _ _ (GadtC _ [(_, AppT (ConT n) (VarT _))] _)) = digItOut n - dig (GadtC _ _ (AppT (ConT _) (ConT n))) = return [n] - dig _ = error "Unexpected case generating ToFloating instances" - -- - concat `fmap` mapM dig cons - - thToFloating :: Name -> Name -> Q Dec - thToFloating a b = - let - ty = AppT (AppT (ConT (mkName "ToFloating")) (ConT a)) (ConT b) - dec = ValD (VarP (mkName "toFloating")) (NormalB (VarE (mkName f))) [] - f | a == b = "id" - | otherwise = "mkToFloating" - in - instanceD (return []) (return ty) [return dec] - -- - as <- digItOut ''NumType - bs <- digItOut ''FloatingType - sequence [ thToFloating a b | a <- as, b <- bs ] - ) +runQ $ + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + numTypes :: [Name] + numTypes = integralTypes ++ floatingTypes + + thToFloating :: Name -> Name -> Q [Dec] + thToFloating a b = + [d| instance ToFloating $(conT a) $(conT b) where + toFloating = $(varE $ if a == b then 'id else 'mkToFloating) + + instance KnownNat n => ToFloating (Vec n $(conT a)) (Vec n $(conT b)) where + toFloating = $(varE $ if a == b then 'id else 'mkToFloating) + |] + in + concat <$> sequence [ thToFloating from to | from <- numTypes, to <- floatingTypes ] diff --git a/src/Data/Array/Accelerate/Classes/VEq.hs b/src/Data/Array/Accelerate/Classes/VEq.hs new file mode 100644 index 000000000..1a248de22 --- /dev/null +++ b/src/Data/Array/Accelerate/Classes/VEq.hs @@ -0,0 +1,232 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +-- | +-- Module : Data.Array.Accelerate.Classes.VEq +-- Copyright : [2016..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Classes.VEq ( + + VEq(..), + (&&*), + (||*), + vnot, + vand, + vor, + +) where + +import Data.Array.Accelerate.AST ( PrimFun(..) ) +import Data.Array.Accelerate.Classes.Eq +import Data.Array.Accelerate.Classes.Num +import Data.Array.Accelerate.Representation.Tag +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Shape +import Data.Array.Accelerate.Sugar.Vec +import Data.Array.Accelerate.Type + +import qualified Data.Primitive.Bit as Prim + +import Language.Haskell.TH.Extra hiding ( Type, Exp ) + +import Prelude hiding ( Eq(..) ) + + +-- | Vectorised conjunction: Element-wise returns true if both arguments in +-- the corresponding lane are True. This is a strict vectorised version of +-- '(Data.Array.Accelerate.&&)' that always evaluates both arguments. +-- +-- @since 1.4.0.0 +-- +infixr 3 &&* +(&&*) :: KnownNat n => Exp (Vec n Bool) -> Exp (Vec n Bool) -> Exp (Vec n Bool) +(&&*) = mkPrimBinary $ PrimLAnd bitType + +-- | Vectorised disjunction: Element-wise returns true if either argument +-- in the corresponding lane is true. This is a strict vectorised version +-- of '(Data.Array.Accelerate.||)' that always evaluates both arguments. +-- +-- @since 1.4.0.0 +-- +infixr 2 ||* +(||*) :: KnownNat n => Exp (Vec n Bool) -> Exp (Vec n Bool) -> Exp (Vec n Bool) +(||*) = mkPrimBinary $ PrimLOr bitType + +-- | Vectorised logical negation +-- +-- @since 1.4.0.0 +-- +vnot :: KnownNat n => Exp (Vec n Bool) -> Exp (Vec n Bool) +vnot = mkPrimUnary $ PrimLNot bitType + +-- | Return 'True' if all lanes of the vector are 'True' +-- +-- @since 1.4.0.0 +-- +vand :: KnownNat n => Exp (Vec n Bool) -> Exp Bool +vand = mkPrimUnary $ PrimVLAnd bitType + +-- | Return 'True' if any lane of the vector is 'True' +-- +-- @since 1.4.0.0 +-- +vor :: KnownNat n => Exp (Vec n Bool) -> Exp Bool +vor = mkPrimUnary $ PrimVLOr bitType + + +infix 4 ==* +infix 4 /=* + +-- | The 'VEq' class defines lane-wise equality '(==*)' and inequality +-- '(/=*)' for Accelerate vector expressions. +-- +class SIMD n a => VEq n a where + (==*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + (/=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + {-# MINIMAL (==*) | (/=*) #-} + x ==* y = vnot (x /=* y) + x /=* y = vnot (x ==* y) + +runQ $ do + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + nonNumTypes :: [Name] + nonNumTypes = + [ ''Char + ] + + numTypes :: [Name] + numTypes = integralTypes ++ floatingTypes + + mkPrim :: Name -> Q [Dec] + mkPrim name = + [d| instance KnownNat n => VEq n $(conT name) where + (==*) = mkPrimBinary $ PrimEq scalarType + (/=*) = mkPrimBinary $ PrimNEq scalarType + |] + + mkTup :: Word8 -> Q Dec + mkTup n = do + w <- newName "w" + x <- newName "x" + y <- newName "y" + let + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + ts = map varT xs + res = tupT ts + ctx = (++) <$> mapM (appT [t| Eq |]) ts + <*> mapM (appT [t| SIMD $(varT w) |]) ts + + cmp f = [| pack (zipWith $f (unpack $(varE x)) (unpack $(varE y))) |] + -- + instanceD ctx [t| VEq $(varT w) $res |] + [ funD (mkName "==*") [ clause [varP x, varP y] (normalB (cmp [| (==) |])) [] ] + , funD (mkName "/=*") [ clause [varP x, varP y] (normalB (cmp [| (/=) |])) [] ] + ] + -- + ps <- concat <$> mapM mkPrim (numTypes ++ nonNumTypes) + ts <- mapM mkTup [2..16] + return (ps ++ ts) + +vtrue, vfalse :: KnownNat n => Exp (Vec n Bool) +vtrue = constant (Vec (Prim.unMask Prim.ones)) +vfalse = constant (Vec (Prim.unMask Prim.zeros)) + +instance KnownNat n => VEq n () where + _ ==* _ = vtrue + _ /=* _ = vfalse + +instance KnownNat n => VEq n Z where + _ ==* _ = vtrue + _ /=* _ = vfalse + +instance KnownNat n => VEq n Bool where + (==*) = mkPrimBinary $ PrimEq scalarType + (/=*) = mkPrimBinary $ PrimNEq scalarType + +{-- + (==*) = + let n = natVal' (proxy# :: Proxy# n) + -- + cmp :: forall t. (Elt t, IsIntegral (EltR t)) + => Exp (Vec n Bool) + -> Exp (Vec n Bool) + -> Exp (Vec n Bool) + cmp x y = + let x' = mkPrimUnary (PrimFromBool bitType integralType) x :: Exp t + y' = mkPrimUnary (PrimFromBool bitType integralType) y + in + mkPrimUnary (PrimToBool integralType bitType) (mkBAnd x' y') + in + if n <= 8 then cmp @Word8 else + if n <= 16 then cmp @Word16 else + if n <= 32 then cmp @Word32 else + if n <= 64 then cmp @Word64 else + if n <= 128 then cmp @Word128 else + internalError "Can not handle Vec types with more than 128 lanes" + + (/=*) = + let n = natVal' (proxy# :: Proxy# n) + -- + cmp :: forall t. (Elt t, IsIntegral (EltR t)) + => Exp (Vec n Bool) + -> Exp (Vec n Bool) + -> Exp (Vec n Bool) + cmp x y = + let x' = mkPrimUnary (PrimFromBool bitType integralType) x :: Exp t + y' = mkPrimUnary (PrimFromBool bitType integralType) y + in + mkPrimUnary (PrimToBool integralType bitType) (mkBXor x' y') + in + if n <= 8 then cmp @Word8 else + if n <= 16 then cmp @Word16 else + if n <= 32 then cmp @Word32 else + if n <= 64 then cmp @Word64 else + if n <= 128 then cmp @Word128 else + internalError "Can not handle SIMD vector types with more than 128 lanes" +--} + +instance (Eq sh, SIMD n sh) => VEq n (sh :. Int) where + x ==* y = pack (zipWith (==) (unpack x) (unpack y)) + x /=* y = pack (zipWith (/=) (unpack x) (unpack y)) + +instance KnownNat n => VEq n Ordering where + x ==* y = mkCoerce x ==* (mkCoerce y :: Exp (Vec n TAG)) + x /=* y = mkCoerce x /=* (mkCoerce y :: Exp (Vec n TAG)) + diff --git a/src/Data/Array/Accelerate/Classes/VEq.hs-boot b/src/Data/Array/Accelerate/Classes/VEq.hs-boot new file mode 100644 index 000000000..4b99452f4 --- /dev/null +++ b/src/Data/Array/Accelerate/Classes/VEq.hs-boot @@ -0,0 +1,27 @@ +{-# LANGUAGE MultiParamTypeClasses #-} +-- | +-- Module : Data.Array.Accelerate.Classes.VEq +-- Copyright : [2016..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Classes.VEq + where + +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Vec + +class SIMD n a => VEq n a where + (==*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + (/=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + {-# MINIMAL (==*) | (/=*) #-} + (==*) = undefined + (/=*) = undefined + +vand :: KnownNat n => Exp (Vec n Bool) -> Exp Bool +vor :: KnownNat n => Exp (Vec n Bool) -> Exp Bool + diff --git a/src/Data/Array/Accelerate/Classes/VNum.hs b/src/Data/Array/Accelerate/Classes/VNum.hs new file mode 100644 index 000000000..c1964b092 --- /dev/null +++ b/src/Data/Array/Accelerate/Classes/VNum.hs @@ -0,0 +1,81 @@ +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TemplateHaskell #-} +-- | +-- Module : Data.Array.Accelerate.Classes.VNum +-- Copyright : [2016..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Classes.VNum ( + + VNum(..), + +) where + +import Data.Array.Accelerate.AST ( PrimFun(..) ) +import Data.Array.Accelerate.Sugar.Vec +import Data.Array.Accelerate.Type +import Data.Array.Accelerate.Smart + +import Language.Haskell.TH.Extra hiding ( Type, Exp ) + + +-- | The 'VNum' class defines numeric operations over SIMD vectors. +-- +-- @since 1.4.0.0 +-- +class SIMD n a => VNum n a where + -- | Horizontal reduction of a vector with addition. This operation is not + -- guaranteed to preserve the associativity of an equivalent scalarised + -- counterpart. + vadd :: Exp (Vec n a) -> Exp a + + -- | Horizontal reduction of a vector with multiplication. This operation is + -- not guaranteed to preserve the associativity of an equivalent scalarised + -- counterpart. + vmul :: Exp (Vec n a) -> Exp a + + +runQ $ + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + numTypes :: [Name] + numTypes = integralTypes ++ floatingTypes + + thVNum :: Name -> Q [Dec] + thVNum name = + [d| instance KnownNat n => VNum n $(conT name) where + vadd = mkPrimUnary $ PrimVAdd numType + vmul = mkPrimUnary $ PrimVMul numType + |] + in + concat <$> mapM thVNum numTypes + diff --git a/src/Data/Array/Accelerate/Classes/VOrd.hs b/src/Data/Array/Accelerate/Classes/VOrd.hs new file mode 100644 index 000000000..bd4971eff --- /dev/null +++ b/src/Data/Array/Accelerate/Classes/VOrd.hs @@ -0,0 +1,186 @@ +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeOperators #-} +-- | +-- Module : Data.Array.Accelerate.Classes.VOrd +-- Copyright : [2016..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Classes.VOrd ( + + VOrd(..), + +) where + +import Data.Array.Accelerate.AST ( PrimFun(..) ) +import Data.Array.Accelerate.Classes.Ord +import Data.Array.Accelerate.Classes.VEq +import Data.Array.Accelerate.Representation.Tag +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Shape +import Data.Array.Accelerate.Sugar.Vec +import Data.Array.Accelerate.Type + +import qualified Data.Primitive.Vec as Prim +import qualified Data.Primitive.Bit as Prim + +import Language.Haskell.TH.Extra hiding ( Type, Exp ) + +import Prelude hiding ( Ord(..), Ordering(..), (<*) ) +import qualified Prelude as P + + +infix 4 <* +infix 4 >* +infix 4 <=* +infix 4 >=* + +-- | The 'VOrd' class defines lane-wise comparisons for totally ordered +-- data types. +-- +-- @since 1.4.0.0 +-- +class VEq n a => VOrd n a where + {-# MINIMAL (<=*) | vcompare #-} + (<*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + (>*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + (<=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + (>=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + vmin :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n a) + vmax :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n a) + vminimum :: Exp (Vec n a) -> Exp a + vmaximum :: Exp (Vec n a) -> Exp a + vcompare :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Ordering) + + x <* y = select (vcompare x y ==* vlt) vtrue vfalse + x <=* y = select (vcompare x y ==* vgt) vfalse vtrue + x >* y = select (vcompare x y ==* vgt) vtrue vfalse + x >=* y = select (vcompare x y ==* vlt) vfalse vtrue + + vmin x y = select (x <=* y) x y + vmax x y = select (x <=* y) y x + + default vminimum :: Ord a => Exp (Vec n a) -> Exp a + default vmaximum :: Ord a => Exp (Vec n a) -> Exp a + vminimum x = P.minimum (unpack x) + vmaximum x = P.maximum (unpack x) + + vcompare x y + = select (x ==* y) veq + $ select (x <=* y) vlt vgt + +vlt, veq, vgt :: KnownNat n => Exp (Vec n Ordering) +vlt = constant (Vec (let (tag,()) = fromElt P.LT in Prim.splat tag, ())) +veq = constant (Vec (let (tag,()) = fromElt P.EQ in Prim.splat tag, ())) +vgt = constant (Vec (let (tag,()) = fromElt P.GT in Prim.splat tag, ())) + +vtrue, vfalse :: KnownNat n => Exp (Vec n Bool) +vtrue = constant (Vec (Prim.unMask Prim.ones)) +vfalse = constant (Vec (Prim.unMask Prim.zeros)) + +runQ $ do + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + nonNumTypes :: [Name] + nonNumTypes = + [ ''Char + ] + + numTypes :: [Name] + numTypes = integralTypes ++ floatingTypes + + mkPrim :: Name -> Q [Dec] + mkPrim name = + [d| instance KnownNat n => VOrd n $(conT name) where + (<*) = mkPrimBinary $ PrimLt scalarType + (>*) = mkPrimBinary $ PrimGt scalarType + (<=*) = mkPrimBinary $ PrimLtEq scalarType + (>=*) = mkPrimBinary $ PrimGtEq scalarType + vmin = mkPrimBinary $ PrimMin scalarType + vmax = mkPrimBinary $ PrimMax scalarType + vminimum = mkPrimUnary $ PrimVMin scalarType + vmaximum = mkPrimUnary $ PrimVMax scalarType + |] + + mkTup :: Word8 -> Q Dec + mkTup n = do + w <- newName "w" + x <- newName "x" + y <- newName "y" + let + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + ts = map varT xs + res = tupT ts + ctx = (++) <$> mapM (appT [t| Ord |]) ts + <*> mapM (appT [t| SIMD $(varT w) |]) ts + cmp f = [| pack (zipWith $f (unpack $(varE x)) (unpack $(varE y))) |] + -- + instanceD ctx [t| VOrd $(varT w) $res |] + [ funD (mkName "<*") [ clause [varP x, varP y] (normalB (cmp [| (<) |])) [] ] + , funD (mkName ">*") [ clause [varP x, varP y] (normalB (cmp [| (>) |])) [] ] + , funD (mkName "<=*") [ clause [varP x, varP y] (normalB (cmp [| (<=) |])) [] ] + , funD (mkName ">=*") [ clause [varP x, varP y] (normalB (cmp [| (>=) |])) [] ] + ] + -- + ps <- concat <$> mapM mkPrim (numTypes ++ nonNumTypes) + ts <- mapM mkTup [2..16] + return (ps ++ ts) + +instance KnownNat n => VOrd n () where + (<*) _ _ = vfalse + (>*) _ _ = vfalse + (<=*) _ _ = vtrue + (>=*) _ _ = vtrue + vcompare _ _ = veq + +instance KnownNat n => VOrd n Z where + (<*) _ _ = vfalse + (>*) _ _ = vfalse + (<=*) _ _ = vtrue + (>=*) _ _ = vtrue + vcompare _ _ = veq + +instance KnownNat n => VOrd n Ordering where + x <* y = mkCoerce x <* (mkCoerce y :: Exp (Vec n TAG)) + x >* y = mkCoerce x >* (mkCoerce y :: Exp (Vec n TAG)) + x <=* y = mkCoerce x <=* (mkCoerce y :: Exp (Vec n TAG)) + x >=* y = mkCoerce x >=* (mkCoerce y :: Exp (Vec n TAG)) + +instance (Ord sh, VOrd n sh) => VOrd n (sh :. Int) where + x <* y = pack (zipWith (<) (unpack x) (unpack y)) + x >* y = pack (zipWith (>) (unpack x) (unpack y)) + x <=* y = pack (zipWith (<=) (unpack x) (unpack y)) + x >=* y = pack (zipWith (>=) (unpack x) (unpack y)) + diff --git a/src/Data/Array/Accelerate/Classes/VOrd.hs-boot b/src/Data/Array/Accelerate/Classes/VOrd.hs-boot new file mode 100644 index 000000000..f379e784b --- /dev/null +++ b/src/Data/Array/Accelerate/Classes/VOrd.hs-boot @@ -0,0 +1,53 @@ +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE MultiParamTypeClasses #-} +-- | +-- Module : Data.Array.Accelerate.Classes.VOrd +-- Copyright : [2016..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Classes.VOrd ( + + VOrd(..), + +) where + +import Data.Array.Accelerate.Classes.VEq +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Vec +import {-# SOURCE #-} Data.Array.Accelerate.Classes.Ord + +import Prelude hiding ( Ord(..), Ordering(..), (<*) ) + + +class VEq n a => VOrd n a where + {-# MINIMAL (<=*) | vcompare #-} + (<*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + (>*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + (<=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + (>=*) :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Bool) + vmin :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n a) + vmax :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n a) + vminimum :: Exp (Vec n a) -> Exp a + vmaximum :: Exp (Vec n a) -> Exp a + vcompare :: Exp (Vec n a) -> Exp (Vec n a) -> Exp (Vec n Ordering) + + (<*) = undefined + (<=*) = undefined + (>*) = undefined + (>=*) = undefined + + vmin = undefined + vmax = undefined + + default vminimum :: Ord a => Exp (Vec n a) -> Exp a + default vmaximum :: Ord a => Exp (Vec n a) -> Exp a + vminimum = undefined + vmaximum = undefined + + vcompare = undefined + diff --git a/src/Data/Array/Accelerate/Data/Bits.hs b/src/Data/Array/Accelerate/Data/Bits.hs index 9696e3c3e..491a4e159 100644 --- a/src/Data/Array/Accelerate/Data/Bits.hs +++ b/src/Data/Array/Accelerate/Data/Bits.hs @@ -1,9 +1,15 @@ {-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE RebindableSyntax #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} -- | -- Module : Data.Array.Accelerate.Data.Bits @@ -24,21 +30,32 @@ module Data.Array.Accelerate.Data.Bits ( ) where -import Data.Array.Accelerate.Array.Data +import Data.Array.Accelerate.AST ( PrimFun(..), BitOrMask ) import Data.Array.Accelerate.Language import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Eq +import Data.Array.Accelerate.Classes.FromIntegral +import Data.Array.Accelerate.Classes.Integral ( ) +import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Classes.Ord -import Data.Array.Accelerate.Classes.Integral () - -import Prelude ( (.), ($), undefined, otherwise ) +import Data.Array.Accelerate.Classes.VEq +import Data.Array.Accelerate.Classes.VOrd + +import Data.Kind +import Data.Type.Equality +import Language.Haskell.TH hiding ( Exp, Type ) +import Prelude ( ($), (<$>), undefined, otherwise, concat, mapM, toInteger ) +import qualified Prelude as P import qualified Data.Bits as B +import GHC.Prim ( Proxy#, proxy# ) +import GHC.TypeLits -infixl 8 `shift`, `rotate`, `shiftL`, `shiftR`, `rotateL`, `rotateR` +infixl 8 `shiftL`, `shiftR`, `rotateL`, `rotateR` infixl 7 .&. infixl 6 `xor` infixl 5 .|. @@ -49,10 +66,8 @@ infixl 5 .|. -- significant bit. -- class Eq a => Bits a where - {-# MINIMAL (.&.), (.|.), xor, complement, - (shift | (shiftL, shiftR)), - (rotate | (rotateL, rotateR)), - isSigned, testBit, bit, popCount #-} + type Bools a :: Type + {-# MINIMAL (.&.), (.|.), xor, complement, shiftL, shiftR, rotateL, rotateR, isSigned, bit, testBit, popCount #-} -- | Bitwise "and" (.&.) :: Exp a -> Exp a -> Exp a @@ -66,58 +81,40 @@ class Eq a => Bits a where -- | Reverse all bits in the argument complement :: Exp a -> Exp a - -- | @'shift' x i@ shifts @x@ left by @i@ bits if @i@ is positive, or right by - -- @-i@ bits otherwise. Right shifts perform sign extension on signed number - -- types; i.e. they fill the top bits with 1 if the @x@ is negative and with - -- 0 otherwise. - shift :: Exp a -> Exp Int -> Exp a - shift x i - = cond (i < 0) (x `shiftR` (-i)) - $ cond (i > 0) (x `shiftL` i) - $ x - - -- | @'rotate' x i@ rotates @x@ left by @i@ bits if @i@ is positive, or right - -- by @-i@ bits otherwise. - rotate :: Exp a -> Exp Int -> Exp a - rotate x i - = cond (i < 0) (x `rotateR` (-i)) - $ cond (i > 0) (x `rotateL` i) - $ x - -- | The value with all bits unset zeroBits :: Exp a + default zeroBits :: Num a => Exp a zeroBits = clearBit (bit 0) 0 -- | @bit /i/@ is a value with the @/i/@th bit set and all other bits clear. - bit :: Exp Int -> Exp a + bit :: Exp a -> Exp a -- | @x \`setBit\` i@ is the same as @x .|. bit i@ - setBit :: Exp a -> Exp Int -> Exp a + setBit :: Exp a -> Exp a -> Exp a setBit x i = x .|. bit i -- | @x \`clearBit\` i@ is the same as @x .&. complement (bit i)@ - clearBit :: Exp a -> Exp Int -> Exp a + clearBit :: Exp a -> Exp a -> Exp a clearBit x i = x .&. complement (bit i) -- | @x \`complementBit\` i@ is the same as @x \`xor\` bit i@ - complementBit :: Exp a -> Exp Int -> Exp a + complementBit :: Exp a -> Exp a -> Exp a complementBit x i = x `xor` bit i -- | Return 'True' if the @n@th bit of the argument is 1 - testBit :: Exp a -> Exp Int -> Exp Bool + testBit :: Exp a -> Exp a -> Exp (Bools a) -- | Return 'True' if the argument is a signed type. isSigned :: Exp a -> Exp Bool -- | Shift the argument left by the specified number of bits (which must be - -- non-negative). - shiftL :: Exp a -> Exp Int -> Exp a - shiftL x i = x `shift` i + -- non-negative) + shiftL :: Exp a -> Exp a -> Exp a -- | Shift the argument left by the specified number of bits. The result is -- undefined for negative shift amounts and shift amounts greater or equal to -- the 'finiteBitSize'. - unsafeShiftL :: Exp a -> Exp Int -> Exp a + unsafeShiftL :: Exp a -> Exp a -> Exp a unsafeShiftL = shiftL -- | Shift the first argument right by the specified number of bits (which @@ -125,39 +122,36 @@ class Eq a => Bits a where -- -- Right shifts perform sign extension on signed number types; i.e. they fill -- the top bits with 1 if @x@ is negative and with 0 otherwise. - shiftR :: Exp a -> Exp Int -> Exp a - shiftR x i = x `shift` (-i) + shiftR :: Exp a -> Exp a -> Exp a -- | Shift the first argument right by the specified number of bits. The -- result is undefined for negative shift amounts and shift amounts greater or -- equal to the 'finiteBitSize'. - unsafeShiftR :: Exp a -> Exp Int -> Exp a + unsafeShiftR :: Exp a -> Exp a -> Exp a unsafeShiftR = shiftR -- | Rotate the argument left by the specified number of bits (which must be -- non-negative). - rotateL :: Exp a -> Exp Int -> Exp a - rotateL x i = x `rotate` i + rotateL :: Exp a -> Exp a -> Exp a -- | Rotate the argument right by the specified number of bits (which must be non-negative). - rotateR :: Exp a -> Exp Int -> Exp a - rotateR x i = x `rotate` (-i) + rotateR :: Exp a -> Exp a -> Exp a -- | Return the number of set bits in the argument. This number is known as -- the population count or the Hamming weight. - popCount :: Exp a -> Exp Int + popCount :: Exp a -> Exp a -class Bits b => FiniteBits b where +class Bits a => FiniteBits a where -- | Return the number of bits in the type of the argument. - finiteBitSize :: Exp b -> Exp Int + finiteBitSize :: Exp a -> Exp Int -- | Count the number of zero bits preceding the most significant set bit. -- This can be used to compute a base-2 logarithm via: -- -- > logBase2 x = finiteBitSize x - 1 - countLeadingZeros x -- - countLeadingZeros :: Exp b -> Exp Int + countLeadingZeros :: Exp a -> Exp a -- | Count the number of zero bits following the least significant set bit. -- The related @@ -166,558 +160,89 @@ class Bits b => FiniteBits b where -- -- > findFirstSet x = 1 + countTrailingZeros x -- - countTrailingZeros :: Exp b -> Exp Int + countTrailingZeros :: Exp a -> Exp a + + -- | Reverse the order of bits + -- + -- @since 1.4.0.0 + -- + bitreverse :: Exp a -> Exp a + + -- | Reverse the order of bytes + -- + -- @since 1.4.0.0 + -- + byteswap :: Exp a -> Exp a -- Instances for Bits -- ------------------ instance Bits Bool where + type Bools Bool = Bool (.&.) = (&&) (.|.) = (||) xor = (/=) complement = not - shift x i = cond (i == 0) x (constant False) - testBit x i = cond (i == 0) x (constant False) - rotate x _ = x - bit i = i == 0 - isSigned = isSignedDefault - popCount = boolToInt - -instance Bits Int where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Int8 where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Int16 where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Int32 where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Int64 where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Word where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Word8 where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Word16 where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Word32 where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits Word64 where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = bitDefault - testBit = testBitDefault - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault + zeroBits = False + testBit = (&&) + bit x = x isSigned = isSignedDefault - popCount = mkPopCount - -instance Bits CInt where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @Int32 - testBit b = testBitDefault (mkBitcast @Int32 b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @Int32 - -instance Bits CUInt where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @Word32 - testBit b = testBitDefault (mkBitcast @Word32 b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @Word32 - -instance Bits CLong where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @HTYPE_CLONG - testBit b = testBitDefault (mkBitcast @HTYPE_CLONG b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @HTYPE_CLONG - -instance Bits CULong where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @HTYPE_CULONG - testBit b = testBitDefault (mkBitcast @HTYPE_CULONG b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @HTYPE_CULONG - -instance Bits CLLong where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @Int64 - testBit b = testBitDefault (mkBitcast @Int64 b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @Int64 - -instance Bits CULLong where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @Word64 - testBit b = testBitDefault (mkBitcast @Word64 b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @Word64 - -instance Bits CShort where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @Int16 - testBit b = testBitDefault (mkBitcast @Int16 b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @Int16 - -instance Bits CUShort where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @Word16 - testBit b = testBitDefault (mkBitcast @Word16 b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @Word16 - -instance Bits CChar where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @HTYPE_CCHAR - testBit b = testBitDefault (mkBitcast @HTYPE_CCHAR b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @HTYPE_CCHAR - -instance Bits CSChar where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @Int8 - testBit b = testBitDefault (mkBitcast @Int8 b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @Int8 - -instance Bits CUChar where - (.&.) = mkBAnd - (.|.) = mkBOr - xor = mkBXor - complement = mkBNot - bit = mkBitcast . bitDefault @Word8 - testBit b = testBitDefault (mkBitcast @Word8 b) - shift = shiftDefault - shiftL = shiftLDefault - shiftR = shiftRDefault - unsafeShiftL = mkBShiftL - unsafeShiftR = mkBShiftR - rotate = rotateDefault - rotateL = rotateLDefault - rotateR = rotateRDefault - isSigned = isSignedDefault - popCount = mkPopCount . mkBitcast @Word8 - - - --- Instances for FiniteBits --- ------------------------ + shiftL x i = cond i False x + shiftR x i = cond i False x + rotateL x _ = x + rotateR x _ = x + popCount x = x instance FiniteBits Bool where - finiteBitSize _ = constInt 8 -- stored as Word8 {- (B.finiteBitSize (undefined::Bool)) -} - countLeadingZeros x = cond x 0 1 - countTrailingZeros x = cond x 0 1 - -instance FiniteBits Int where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Int)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Int8 where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Int8)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Int16 where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Int16)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Int32 where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Int32)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Int64 where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Int64)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Word where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Word)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Word8 where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Word8)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Word16 where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Word16)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Word32 where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Word32)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits Word64 where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::Word64)) - countLeadingZeros = mkCountLeadingZeros - countTrailingZeros = mkCountTrailingZeros - -instance FiniteBits CInt where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CInt)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @Int32 - countTrailingZeros = mkCountTrailingZeros . mkBitcast @Int32 - -instance FiniteBits CUInt where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CUInt)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @Word32 - countTrailingZeros = mkCountTrailingZeros . mkBitcast @Word32 - -instance FiniteBits CLong where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CLong)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @HTYPE_CLONG - countTrailingZeros = mkCountTrailingZeros . mkBitcast @HTYPE_CLONG - -instance FiniteBits CULong where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CULong)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @HTYPE_CULONG - countTrailingZeros = mkCountTrailingZeros . mkBitcast @HTYPE_CULONG - -instance FiniteBits CLLong where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CLLong)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @Int64 - countTrailingZeros = mkCountTrailingZeros . mkBitcast @Int64 - -instance FiniteBits CULLong where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CULLong)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @Word64 - countTrailingZeros = mkCountTrailingZeros . mkBitcast @Word64 - -instance FiniteBits CShort where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CShort)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @Int16 - countTrailingZeros = mkCountTrailingZeros . mkBitcast @Int16 - -instance FiniteBits CUShort where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CUShort)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @Word16 - countTrailingZeros = mkCountTrailingZeros . mkBitcast @Word16 - -instance FiniteBits CChar where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CChar)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @HTYPE_CCHAR - countTrailingZeros = mkCountTrailingZeros . mkBitcast @HTYPE_CCHAR - -instance FiniteBits CSChar where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CSChar)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @Int8 - countTrailingZeros = mkCountTrailingZeros . mkBitcast @Int8 - -instance FiniteBits CUChar where - finiteBitSize _ = constInt (B.finiteBitSize (undefined::CUChar)) - countLeadingZeros = mkCountLeadingZeros . mkBitcast @Word8 - countTrailingZeros = mkCountTrailingZeros . mkBitcast @Word8 + finiteBitSize _ = fromInteger (toInteger (B.finiteBitSize (undefined::Bool))) + countLeadingZeros x = x + countTrailingZeros x = x + bitreverse x = x + byteswap x = x -- Default implementations -- ----------------------- -bitDefault :: (IsIntegral (EltR t), Bits t) => Exp Int -> Exp t -bitDefault x = constInt 1 `shiftL` x -testBitDefault :: (IsIntegral (EltR t), Bits t) => Exp t -> Exp Int -> Exp Bool -testBitDefault x i = (x .&. bit i) /= constInt 0 +bitDefault :: (Num t, Bits t) => Exp t -> Exp t +bitDefault x = 1 `shiftL` x -shiftDefault :: (FiniteBits t, IsIntegral (EltR t), B.Bits t) => Exp t -> Exp Int -> Exp t -shiftDefault x i - = cond (i >= 0) (shiftLDefault x i) - (shiftRDefault x (-i)) +testBitDefault :: (Num t, Bits t) => Exp t -> Exp t -> Exp Bool +testBitDefault x i = (x .&. bit i) /= 0 -shiftLDefault :: (FiniteBits t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t +-- shiftDefault :: (FiniteBits t, IsIntegral (EltR t), B.Bits t) => Exp t -> Exp t -> Exp t +-- shiftDefault x i +-- = cond (i >= 0) (shiftLDefault x i) +-- (shiftRDefault x (-i)) + +shiftLDefault :: forall t. (B.FiniteBits t, Num t, Ord t, FromIntegral Int t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t shiftLDefault x i - = cond (i >= finiteBitSize x) (constInt 0) + = cond (i >= P.fromIntegral (B.finiteBitSize (undefined::t))) 0 $ mkBShiftL x i -shiftRDefault :: forall t. (B.Bits t, FiniteBits t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t +shiftRDefault :: forall t. (B.Bits t, B.FiniteBits t, Num t, Ord t, FromIntegral Int t, IsScalar (EltR t), IsIntegral (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp t -> Exp t shiftRDefault | B.isSigned (undefined::t) = shiftRADefault | otherwise = shiftRLDefault -- Shift the argument right (signed) -shiftRADefault :: (FiniteBits t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t +shiftRADefault :: forall t. (B.FiniteBits t, Num t, Ord t, FromIntegral Int t, IsScalar (EltR t), IsIntegral (EltR t), BitOrMask (EltR t) ~ Bit) => Exp t -> Exp t -> Exp t shiftRADefault x i - = cond (i >= finiteBitSize x) (cond (mkLt x (constInt 0)) (constInt (-1)) (constInt 0)) + = cond (i >= P.fromIntegral (B.finiteBitSize (undefined::t))) (cond (x < 0) (-1) 0) $ mkBShiftR x i -- Shift the argument right (unsigned) -shiftRLDefault :: (FiniteBits t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t +shiftRLDefault :: forall t. (B.FiniteBits t, Num t, Ord t, FromIntegral Int t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t shiftRLDefault x i - = cond (i >= finiteBitSize x) (constInt 0) + = cond (i >= P.fromIntegral (B.finiteBitSize (undefined::t))) 0 $ mkBShiftR x i -rotateDefault :: forall t. (FiniteBits t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t -rotateDefault x i - = cond (i < 0) (mkBRotateR x (-i)) - $ cond (i > 0) (mkBRotateL x i) - $ x +-- rotateDefault :: forall t. (FiniteBits t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t +-- rotateDefault x i +-- = cond (i < 0) (mkBRotateR x (-i)) +-- $ cond (i > 0) (mkBRotateL x i) +-- $ x {-- -- Rotation can be implemented in terms of two shifts, but care is needed @@ -759,22 +284,19 @@ rotateDefault' _ x i wsib = finiteBitSize x --} -rotateLDefault :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t -rotateLDefault x i - = cond (i == 0) x - $ mkBRotateL x i +-- rotateLDefault :: (Num t, Eq t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t +-- rotateLDefault x i +-- = cond (i == 0) x +-- $ mkBRotateL x i -rotateRDefault :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t -rotateRDefault x i - = cond (i == 0) x - $ mkBRotateR x i +-- rotateRDefault :: (Num t, Eq t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t +-- rotateRDefault x i +-- = cond (i == 0) x +-- $ mkBRotateR x i isSignedDefault :: forall b. B.Bits b => Exp b -> Exp Bool isSignedDefault _ = constant (B.isSigned (undefined::b)) -constInt :: IsIntegral (EltR e) => EltR e -> Exp e -constInt = mkExp . Const (SingleScalarType (NumSingleType (IntegralNumType integralType))) - {-- _popCountDefault :: forall a. (B.FiniteBits a, IsScalar a, Bits a, Num a) => Exp a -> Exp Int _popCountDefault = @@ -828,3 +350,116 @@ popCnt64 v1 = mkFromIntegral c c = (v4 * 0x0101010101010101) `unsafeShiftR` 56 --} +mkBAnd :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t +mkBAnd = mkPrimBinary $ PrimBAnd integralType + +mkBOr :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t +mkBOr = mkPrimBinary $ PrimBOr integralType + +mkBXor :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t +mkBXor = mkPrimBinary $ PrimBXor integralType + +mkBNot :: IsIntegral (EltR t) => Exp t -> Exp t +mkBNot = mkPrimUnary $ PrimBNot integralType + +mkBShiftL :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t +mkBShiftL = mkPrimBinary $ PrimBShiftL integralType + +mkBShiftR :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t +mkBShiftR = mkPrimBinary $ PrimBShiftR integralType + +mkBRotateL :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t +mkBRotateL = mkPrimBinary $ PrimBRotateL integralType + +mkBRotateR :: IsIntegral (EltR t) => Exp t -> Exp t -> Exp t +mkBRotateR = mkPrimBinary $ PrimBRotateR integralType + +mkPopCount :: IsIntegral (EltR t) => Exp t -> Exp t +mkPopCount = mkPrimUnary $ PrimPopCount integralType + +mkCountLeadingZeros :: IsIntegral (EltR t) => Exp t -> Exp t +mkCountLeadingZeros = mkPrimUnary $ PrimCountLeadingZeros integralType + +mkCountTrailingZeros :: IsIntegral (EltR t) => Exp t -> Exp t +mkCountTrailingZeros = mkPrimUnary $ PrimCountTrailingZeros integralType + +mkBReverse :: IsIntegral (EltR t) => Exp t -> Exp t +mkBReverse = mkPrimUnary $ PrimBReverse integralType + +mkBSwap :: IsIntegral (EltR t) => Exp t -> Exp t +mkBSwap = mkPrimUnary $ PrimBSwap integralType + + +runQ $ + let + integralTypes :: [Name] + integralTypes = + [ ''Int + , ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + thBits :: Name -> Q [Dec] + thBits a = + [d| instance Bits $(conT a) where + type Bools $(conT a) = Bool + (.&.) = mkBAnd + (.|.) = mkBOr + xor = mkBXor + complement = mkBNot + bit = bitDefault + testBit = testBitDefault + shiftL = shiftLDefault + shiftR = shiftRDefault + unsafeShiftL = mkBShiftL + unsafeShiftR = mkBShiftR + rotateL = mkBRotateL + rotateR = mkBRotateR + isSigned = isSignedDefault + popCount = mkPopCount + + instance KnownNat n => Bits (Vec n $(conT a)) where + type Bools (Vec n $(conT a)) = (Vec n Bool) + (.&.) = mkBAnd + (.|.) = mkBOr + xor = mkBXor + complement = mkBNot + bit = bitDefault + testBit x i = (x .&. bit i) /=* 0 + shiftL x i = select (i >=* P.fromIntegral (B.finiteBitSize (undefined :: $(conT a)))) 0 (mkBShiftL x i) + shiftR x i + | B.isSigned (undefined :: $(conT a)) = select (i >=* P.fromIntegral (B.finiteBitSize (undefined :: $(conT a)))) (select (x <* 0) (P.fromInteger (-1)) 0) (mkBShiftR x i) + | otherwise = select (i >=* P.fromIntegral (B.finiteBitSize (undefined :: $(conT a)))) 0 (mkBShiftR x i) + unsafeShiftL = mkBShiftL + unsafeShiftR = mkBShiftR + rotateL = mkBRotateL + rotateR = mkBRotateR + isSigned _ = constant (B.isSigned (undefined :: $(conT a))) + popCount = mkPopCount + + instance FiniteBits $(conT a) where + finiteBitSize _ = fromInteger (toInteger (B.finiteBitSize (undefined :: $(conT a)))) + countLeadingZeros = mkCountLeadingZeros + countTrailingZeros = mkCountTrailingZeros + bitreverse = mkBReverse + byteswap = mkBSwap + + instance KnownNat n => FiniteBits (Vec n $(conT a)) where + finiteBitSize _ = fromInteger (natVal' (proxy# :: Proxy# n) * toInteger (B.finiteBitSize (undefined :: $(conT a)))) + countLeadingZeros = mkCountLeadingZeros + countTrailingZeros = mkCountTrailingZeros + bitreverse = mkBReverse + byteswap = mkBSwap + |] + in + concat <$> mapM thBits integralTypes + diff --git a/src/Data/Array/Accelerate/Data/Complex.hs b/src/Data/Array/Accelerate/Data/Complex.hs index 1ac4ba771..6872c1db7 100644 --- a/src/Data/Array/Accelerate/Data/Complex.hs +++ b/src/Data/Array/Accelerate/Data/Complex.hs @@ -1,19 +1,20 @@ -{-# LANGUAGE CPP #-} -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE MagicHash #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RebindableSyntax #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeSynonymInstances #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RebindableSyntax #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.Data.Complex @@ -30,7 +31,7 @@ module Data.Array.Accelerate.Data.Complex ( -- * Rectangular from - Complex(..), pattern (::+), + Complex, pattern (:+), real, imag, @@ -46,6 +47,8 @@ module Data.Array.Accelerate.Data.Complex ( ) where +import Data.Array.Accelerate.AST ( PrimFun(..), BitOrMask ) +import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.Classes.Eq import Data.Array.Accelerate.Classes.Floating import Data.Array.Accelerate.Classes.Fractional @@ -55,17 +58,17 @@ import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Classes.RealFloat import Data.Array.Accelerate.Data.Functor import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Prelude import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type -import Data.Array.Accelerate.Representation.Vec -import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Smart hiding ( pack, unpack ) import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type -import Data.Primitive.Vec +import qualified Data.Primitive.Vec as Prim -import Data.Complex ( Complex(..) ) +import Data.Complex ( Complex ) +import Data.Primitive.Types import Prelude ( ($) ) import qualified Data.Complex as C import qualified Prelude as P @@ -74,30 +77,43 @@ import Data.Type.Equality #endif -infix 6 ::+ -pattern (::+) :: Elt a => Exp a -> Exp a -> Exp (Complex a) -pattern r ::+ i <- (deconstructComplex -> (r, i)) - where (::+) = constructComplex -{-# COMPLETE (::+) #-} +infix 6 :+ +pattern (:+) :: IsComplex a b => a -> a -> b +pattern r :+ i <- (matchComplex -> (r,i)) + where (:+) = buildComplex +{-# COMPLETE (:+) :: Complex #-} +{-# COMPLETE (:+) :: Exp #-} + +class IsComplex a b | b -> a where + matchComplex :: b -> (a, a) + buildComplex :: a -> a -> b + +instance IsComplex a (Complex a) where + buildComplex = (C.:+) + matchComplex (r C.:+ i) = (r, i) -- Use an array-of-structs representation for complex numbers if possible. --- This matches the standard C-style layout, but we can use this representation only at --- specific types (not for any type 'a') as we can only have vectors of primitive type. --- For other types, we use a structure-of-arrays representation. This is handled by the --- ComplexR. We use the GADT ComplexR and function complexR to reconstruct --- information on how the elements are represented. +-- +-- This matches the standard C-style layout, but we can use this representation +-- only at specific types (not for any type 'a') as we can only have vectors of +-- primitive type. For other types, we use a structure-of-arrays representation. +-- This is handled by the ComplexR. We use the GADT ComplexR and function +-- complexR to reconstruct information on how the elements are represented. +-- +-- TODO: This is no longer true, we could SIMD-ify more types here. +-- - TLM 2023-09-28 -- instance Elt a => Elt (Complex a) where type EltR (Complex a) = ComplexR (EltR a) eltR = let tR = eltR @a in case complexR tR of - ComplexVec s -> TupRsingle $ VectorScalarType $ VectorType 2 s + ComplexVec t -> TupRsingle (NumScalarType t) ComplexTup -> TupRunit `TupRpair` tR `TupRpair` tR tagsR = let tR = eltR @a in case complexR tR of - ComplexVec s -> [ TagRsingle (VectorScalarType (VectorType 2 s)) ] + ComplexVec t -> [ TagRsingle (NumScalarType t) ] ComplexTup -> let go :: TypeR t -> [TagR t] go TupRunit = [TagRunit] go (TupRsingle s) = [TagRsingle s] @@ -106,36 +122,33 @@ instance Elt a => Elt (Complex a) where [ TagRunit `TagRpair` ta `TagRpair` tb | ta <- go tR, tb <- go tR ] toElt = case complexR $ eltR @a of - ComplexVec _ -> \(Vec2 r i) -> toElt r :+ toElt i - ComplexTup -> \(((), r), i) -> toElt r :+ toElt i + ComplexVec _ -> \(Prim.V2 r i) -> toElt r :+ toElt i + ComplexTup -> \(((), r), i) -> toElt r :+ toElt i fromElt (r :+ i) = case complexR $ eltR @a of - ComplexVec _ -> Vec2 (fromElt r) (fromElt i) + ComplexVec _ -> Prim.V2 (fromElt r) (fromElt i) ComplexTup -> (((), fromElt r), fromElt i) type family ComplexR a where - ComplexR Half = Vec2 Half - ComplexR Float = Vec2 Float - ComplexR Double = Vec2 Double - ComplexR Int = Vec2 Int - ComplexR Int8 = Vec2 Int8 - ComplexR Int16 = Vec2 Int16 - ComplexR Int32 = Vec2 Int32 - ComplexR Int64 = Vec2 Int64 - ComplexR Word = Vec2 Word - ComplexR Word8 = Vec2 Word8 - ComplexR Word16 = Vec2 Word16 - ComplexR Word32 = Vec2 Word32 - ComplexR Word64 = Vec2 Word64 - ComplexR a = (((), a), a) - --- This isn't ideal because we gather the evidence based on the --- representation type, so we really get the evidence (VecElt (EltR a)), --- which is not very useful... --- - TLM 2020-07-16 + ComplexR Half = Prim.V2 Float16 + ComplexR Float = Prim.V2 Float32 + ComplexR Double = Prim.V2 Float64 + ComplexR Float128 = Prim.V2 Float128 + ComplexR Int8 = Prim.V2 Int8 + ComplexR Int16 = Prim.V2 Int16 + ComplexR Int32 = Prim.V2 Int32 + ComplexR Int64 = Prim.V2 Int64 + ComplexR Int128 = Prim.V2 Int128 + ComplexR Word8 = Prim.V2 Word8 + ComplexR Word16 = Prim.V2 Word16 + ComplexR Word32 = Prim.V2 Word32 + ComplexR Word64 = Prim.V2 Word64 + ComplexR Word128 = Prim.V2 Word128 + ComplexR a = (((), a), a) + data ComplexType a c where - ComplexVec :: VecElt a => SingleType a -> ComplexType a (Vec2 a) - ComplexTup :: ComplexType a (((), a), a) + ComplexVec :: Prim a => NumType (Prim.V2 a) -> ComplexType a (Prim.V2 a) + ComplexTup :: ComplexType a (((), a), a) complexR :: TypeR a -> ComplexType a (ComplexR a) complexR = tuple @@ -146,96 +159,184 @@ complexR = tuple tuple (TupRsingle s) = scalar s scalar :: ScalarType a -> ComplexType a (ComplexR a) - scalar (SingleScalarType t) = single t - scalar VectorScalarType{} = ComplexTup + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t - single :: SingleType a -> ComplexType a (ComplexR a) - single (NumSingleType t) = num t + bit :: BitType t -> ComplexType t (ComplexR t) + bit TypeBit = ComplexTup + bit TypeMask{} = ComplexTup num :: NumType a -> ComplexType a (ComplexR a) num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t integral :: IntegralType a -> ComplexType a (ComplexR a) - integral TypeInt = ComplexVec singleType - integral TypeInt8 = ComplexVec singleType - integral TypeInt16 = ComplexVec singleType - integral TypeInt32 = ComplexVec singleType - integral TypeInt64 = ComplexVec singleType - integral TypeWord = ComplexVec singleType - integral TypeWord8 = ComplexVec singleType - integral TypeWord16 = ComplexVec singleType - integral TypeWord32 = ComplexVec singleType - integral TypeWord64 = ComplexVec singleType + integral = \case + VectorIntegralType{} -> ComplexTup + SingleIntegralType t -> case t of + TypeInt8 -> ComplexVec numType + TypeInt16 -> ComplexVec numType + TypeInt32 -> ComplexVec numType + TypeInt64 -> ComplexVec numType + TypeInt128 -> ComplexVec numType + TypeWord8 -> ComplexVec numType + TypeWord16 -> ComplexVec numType + TypeWord32 -> ComplexVec numType + TypeWord64 -> ComplexVec numType + TypeWord128 -> ComplexVec numType floating :: FloatingType a -> ComplexType a (ComplexR a) - floating TypeHalf = ComplexVec singleType - floating TypeFloat = ComplexVec singleType - floating TypeDouble = ComplexVec singleType - - -constructComplex :: forall a. Elt a => Exp a -> Exp a -> Exp (Complex a) -constructComplex r i = - case complexR (eltR @a) of - ComplexTup -> coerce $ T2 r i - ComplexVec _ -> V2 (coerce @a @(EltR a) r) (coerce @a @(EltR a) i) - -deconstructComplex :: forall a. Elt a => Exp (Complex a) -> (Exp a, Exp a) -deconstructComplex c@(Exp c') = - case complexR (eltR @a) of - ComplexTup -> let T2 r i = coerce c in (r, i) - ComplexVec t -> let T2 r i = Exp (SmartExp (VecUnpack (VecRsucc (VecRsucc (VecRnil t))) c')) - in (r, i) + floating = \case + VectorFloatingType{} -> ComplexTup + SingleFloatingType t -> case t of + TypeFloat16 -> ComplexVec numType + TypeFloat32 -> ComplexVec numType + TypeFloat64 -> ComplexVec numType + TypeFloat128 -> ComplexVec numType + + +instance Elt a => IsComplex (Exp a) (Exp (Complex a)) where + matchComplex (Exp c) = + case complexR (eltR @a) of + ComplexTup -> + let i = SmartExp (Prj PairIdxRight c) + r = SmartExp (Prj PairIdxRight (SmartExp (Prj PairIdxLeft c))) + in (Exp r, Exp i) + ComplexVec t -> + let (r, i) = num t c + in (Exp r, Exp i) + where + num :: NumType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) + integral (SingleIntegralType t) = case t of + integral (VectorIntegralType n t) = + let v = NumScalarType (IntegralNumType (VectorIntegralType n t)) + in case t of + TypeInt8 -> unpack v + TypeInt16 -> unpack v + TypeInt32 -> unpack v + TypeInt64 -> unpack v + TypeInt128 -> unpack v + TypeWord8 -> unpack v + TypeWord16 -> unpack v + TypeWord32 -> unpack v + TypeWord64 -> unpack v + TypeWord128 -> unpack v + + floating :: FloatingType (Prim.V2 t) -> SmartExp (ComplexR t) -> (SmartExp t, SmartExp t) + floating (SingleFloatingType t) = case t of + floating (VectorFloatingType n t) = + let v = NumScalarType (FloatingNumType (VectorFloatingType n t)) + in case t of + TypeFloat16 -> unpack v + TypeFloat32 -> unpack v + TypeFloat64 -> unpack v + TypeFloat128 -> unpack v + + unpack :: ScalarType (Prim.Vec 2 t) -> SmartExp (Prim.Vec 2 t) -> (SmartExp t, SmartExp t) + unpack v x = + let r = SmartExp (Extract v TypeWord8 x (SmartExp (Const scalarType 0))) + i = SmartExp (Extract v TypeWord8 x (SmartExp (Const scalarType 1))) + in + (r, i) + + buildComplex r@(Exp r') i@(Exp i') = + case complexR (eltR @a) of + ComplexTup -> Pattern (r,i) + ComplexVec t -> Exp $ num t r' i' + where + num :: NumType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) + integral (SingleIntegralType t) = case t of + integral (VectorIntegralType n t) = + let v = NumScalarType (IntegralNumType (VectorIntegralType n t)) + in case t of + TypeInt8 -> pack v + TypeInt16 -> pack v + TypeInt32 -> pack v + TypeInt64 -> pack v + TypeInt128 -> pack v + TypeWord8 -> pack v + TypeWord16 -> pack v + TypeWord32 -> pack v + TypeWord64 -> pack v + TypeWord128 -> pack v + + floating :: FloatingType (Prim.V2 t) -> SmartExp t -> SmartExp t -> SmartExp (ComplexR t) + floating (SingleFloatingType t) = case t of + floating (VectorFloatingType n t) = + let v = NumScalarType (FloatingNumType (VectorFloatingType n t)) + in case t of + TypeFloat16 -> pack v + TypeFloat32 -> pack v + TypeFloat64 -> pack v + TypeFloat128 -> pack v + + pack :: ScalarType (Prim.Vec 2 t) -> SmartExp t -> SmartExp t -> SmartExp (Prim.Vec 2 t) + pack v x y + = SmartExp (Insert v TypeWord8 + (SmartExp (Insert v TypeWord8 (SmartExp (Undef v)) (SmartExp (Const scalarType 0)) x)) + (SmartExp (Const scalarType 1)) y) -coerce :: EltR a ~ EltR b => Exp a -> Exp b -coerce (Exp e) = Exp e instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Complex a) where type Plain (Complex a) = Complex (Plain a) - lift (r :+ i) = lift r ::+ lift i + lift (r :+ i) = lift r :+ lift i instance Elt a => Unlift Exp (Complex (Exp a)) where - unlift (r ::+ i) = r :+ i + unlift (r :+ i) = r :+ i instance Eq a => Eq (Complex a) where - r1 ::+ c1 == r2 ::+ c2 = r1 == r2 && c1 == c2 - r1 ::+ c1 /= r2 ::+ c2 = r1 /= r2 || c1 /= c2 - -instance RealFloat a => P.Num (Exp (Complex a)) where - (+) = lift2 ((+) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a)) - (-) = lift2 ((-) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a)) - (*) = lift2 ((*) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a)) - negate = lift1 (negate :: Complex (Exp a) -> Complex (Exp a)) - signum z@(x ::+ y) = + r1 :+ c1 == r2 :+ c2 = r1 == r2 && c1 == c2 + r1 :+ c1 /= r2 :+ c2 = r1 /= r2 || c1 /= c2 + +instance (RealFloat a, Exponent a ~ Int) => P.Num (Exp (Complex a)) where + (+) = case complexR (eltR @a) of + ComplexTup -> lift2 ((+) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a)) + ComplexVec t -> mkPrimBinary $ PrimAdd t + (-) = case complexR (eltR @a) of + ComplexTup -> lift2 ((-) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a)) + ComplexVec t -> mkPrimBinary $ PrimSub t + (*) = lift2 ((*) :: Complex (Exp a) -> Complex (Exp a) -> Complex (Exp a)) + negate = case complexR (eltR @a) of + ComplexTup -> lift1 (negate :: Complex (Exp a) -> Complex (Exp a)) + ComplexVec t -> mkPrimUnary $ PrimNeg t + signum z@(x :+ y) = if z == 0 then z else let r = magnitude z - in x/r ::+ y/r - abs z = magnitude z ::+ 0 - fromInteger n = fromInteger n ::+ 0 + in x/r :+ y/r + abs z = magnitude z :+ 0 + fromInteger n = fromInteger n :+ 0 -instance RealFloat a => P.Fractional (Exp (Complex a)) where - fromRational x = fromRational x ::+ 0 - z / z' = (x*x''+y*y'') / d ::+ (y*x''-x*y'') / d +instance (RealFloat a, Exponent a ~ Int) => P.Fractional (Exp (Complex a)) where + fromRational x = fromRational x :+ 0 + z / z' = (x*x''+y*y'') / d :+ (y*x''-x*y'') / d where - x :+ y = unlift z - x' :+ y' = unlift z' + x :+ y = z + x' :+ y' = z' -- x'' = scaleFloat k x' y'' = scaleFloat k y' k = - max (exponent x') (exponent y') d = x'*x'' + y'*y'' -instance RealFloat a => P.Floating (Exp (Complex a)) where - pi = pi ::+ 0 - exp (x ::+ y) = let expx = exp x - in expx * cos y ::+ expx * sin y - log z = log (magnitude z) ::+ phase z - sqrt z@(x ::+ y) = +instance (RealFloat a, Exponent a ~ Int, BitOrMask (EltR a) ~ Bit) => P.Floating (Exp (Complex a)) where + pi = pi :+ 0 + exp (x :+ y) = let expx = exp x + in expx * cos y :+ expx * sin y + log z = log (magnitude z) :+ phase z + sqrt z@(x :+ y) = if z == 0 then 0 - else u ::+ (y < 0 ? (-v, v)) + else u :+ (y < 0 ? (-v, v)) where T2 u v = x < 0 ? (T2 v' u', T2 u' v') v' = abs y / (u'*2) @@ -244,50 +345,50 @@ instance RealFloat a => P.Floating (Exp (Complex a)) where x ** y = if y == 0 then 1 else if x == 0 then if exp_r > 0 then 0 else - if exp_r < 0 then inf ::+ 0 - else nan ::+ nan + if exp_r < 0 then inf :+ 0 + else nan :+ nan else if isInfinite r || isInfinite i - then if exp_r > 0 then inf ::+ 0 else + then if exp_r > 0 then inf :+ 0 else if exp_r < 0 then 0 - else nan ::+ nan + else nan :+ nan else exp (log x * y) where - r ::+ i = x - exp_r ::+ _ = y + r :+ i = x + exp_r :+ _ = y -- inf = 1 / 0 nan = 0 / 0 - sin (x ::+ y) = sin x * cosh y ::+ cos x * sinh y - cos (x ::+ y) = cos x * cosh y ::+ (- sin x * sinh y) - tan (x ::+ y) = (sinx*coshy ::+ cosx*sinhy) / (cosx*coshy ::+ (-sinx*sinhy)) + sin (x :+ y) = sin x * cosh y :+ cos x * sinh y + cos (x :+ y) = cos x * cosh y :+ (- sin x * sinh y) + tan (x :+ y) = (sinx*coshy :+ cosx*sinhy) / (cosx*coshy :+ (-sinx*sinhy)) where sinx = sin x cosx = cos x sinhy = sinh y coshy = cosh y - sinh (x ::+ y) = cos y * sinh x ::+ sin y * cosh x - cosh (x ::+ y) = cos y * cosh x ::+ sin y * sinh x - tanh (x ::+ y) = (cosy*sinhx ::+ siny*coshx) / (cosy*coshx ::+ siny*sinhx) + sinh (x :+ y) = cos y * sinh x :+ sin y * cosh x + cosh (x :+ y) = cos y * cosh x :+ sin y * sinh x + tanh (x :+ y) = (cosy*sinhx :+ siny*coshx) / (cosy*coshx :+ siny*sinhx) where siny = sin y cosy = cos y sinhx = sinh x coshx = cosh x - asin z@(x ::+ y) = y' ::+ (-x') + asin z@(x :+ y) = y' :+ (-x') where - x' ::+ y' = log (((-y) ::+ x) + sqrt (1 - z*z)) + x' :+ y' = log (((-y) :+ x) + sqrt (1 - z*z)) - acos z = y'' ::+ (-x'') + acos z = y'' :+ (-x'') where - x'' ::+ y'' = log (z + ((-y') ::+ x')) - x' ::+ y' = sqrt (1 - z*z) + x'' :+ y'' = log (z + ((-y') :+ x')) + x' :+ y' = sqrt (1 - z*z) - atan z@(x ::+ y) = y' ::+ (-x') + atan z@(x :+ y) = y' :+ (-x') where - x' ::+ y' = log (((1-y) ::+ x) / sqrt (1+z*z)) + x' :+ y' = log (((1-y) :+ x) / sqrt (1+z*z)) asinh z = log (z + sqrt (1+z*z)) acosh z = log (z + (z+1) * sqrt ((z-1)/(z+1))) @@ -295,18 +396,18 @@ instance RealFloat a => P.Floating (Exp (Complex a)) where instance (FromIntegral a b, Num b, Elt (Complex b)) => FromIntegral a (Complex b) where - fromIntegral x = fromIntegral x ::+ 0 + fromIntegral x = fromIntegral x :+ 0 -- | @since 1.2.0.0 -- instance Functor Complex where - fmap f (r ::+ i) = f r ::+ f i + fmap f (r :+ i) = f r :+ f i -- | The non-negative magnitude of a complex number -- -magnitude :: RealFloat a => Exp (Complex a) -> Exp a -magnitude (r ::+ i) = scaleFloat k (sqrt (sqr (scaleFloat mk r) + sqr (scaleFloat mk i))) +magnitude :: (RealFloat a, Exponent a ~ Int) => Exp (Complex a) -> Exp a +magnitude (r :+ i) = scaleFloat k (sqrt (sqr (scaleFloat mk r) + sqr (scaleFloat mk i))) where k = max (exponent r) (exponent i) mk = -k @@ -318,14 +419,14 @@ magnitude (r ::+ i) = scaleFloat k (sqrt (sqr (scaleFloat mk r) + sqr (scaleFloa -- @since 1.3.0.0 -- magnitude' :: RealFloat a => Exp (Complex a) -> Exp a -magnitude' (r ::+ i) = sqrt (r*r + i*i) +magnitude' (r :+ i) = sqrt (r*r + i*i) -- | The phase of a complex number, in the range @(-'pi', 'pi']@. If the -- magnitude is zero, then so is the phase. -- phase :: RealFloat a => Exp (Complex a) -> Exp a -phase z@(r ::+ i) = - if z == 0 +phase (r :+ i) = + if r == 0 && i == 0 then 0 else atan2 i r @@ -333,7 +434,7 @@ phase z@(r ::+ i) = -- phase) pair in canonical form: the magnitude is non-negative, and the phase -- in the range @(-'pi', 'pi']@; if the magnitude is zero, then so is the phase. -- -polar :: RealFloat a => Exp (Complex a) -> Exp (a,a) +polar :: (RealFloat a, Exponent a ~ Int) => Exp (Complex a) -> Exp (a,a) polar z = T2 (magnitude z) (phase z) -- | Form a complex number from polar components of magnitude and phase. @@ -350,17 +451,17 @@ cis = lift1 (C.cis :: Exp a -> Complex (Exp a)) -- | Return the real part of a complex number -- real :: Elt a => Exp (Complex a) -> Exp a -real (r ::+ _) = r +real (r :+ _) = r -- | Return the imaginary part of a complex number -- imag :: Elt a => Exp (Complex a) -> Exp a -imag (_ ::+ i) = i +imag (_ :+ i) = i -- | Return the complex conjugate of a complex number, defined as -- -- > conjugate(Z) = X - iY -- conjugate :: Num a => Exp (Complex a) -> Exp (Complex a) -conjugate z = real z ::+ (- imag z) +conjugate z = real z :+ (- imag z) diff --git a/src/Data/Array/Accelerate/Data/Either.hs b/src/Data/Array/Accelerate/Data/Either.hs index 3c5c7401d..a2c4cf52e 100644 --- a/src/Data/Array/Accelerate/Data/Either.hs +++ b/src/Data/Array/Accelerate/Data/Either.hs @@ -28,11 +28,12 @@ module Data.Array.Accelerate.Data.Either ( - Either(..), pattern Left_, pattern Right_, + Either, pattern Left, pattern Right, either, isLeft, isRight, fromLeft, fromRight, lefts, rights, ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.Language import Data.Array.Accelerate.Lift @@ -52,7 +53,6 @@ import Data.Array.Accelerate.Data.Functor import Data.Array.Accelerate.Data.Monoid import Data.Array.Accelerate.Data.Semigroup -import Data.Either ( Either(..) ) import Prelude ( (.), ($) ) @@ -64,7 +64,7 @@ isLeft = not . isRight -- | Return 'True' if the argument is a 'Right'-value -- isRight :: (Elt a, Elt b) => Exp (Either a b) -> Exp Bool -isRight (Exp e) = Exp $ SmartExp $ (SmartExp $ Prj PairIdxLeft e) `Pair` SmartExp Nil +isRight (Exp e) = mkExp $ PrimApp (PrimToBool integralType bitType) (SmartExp $ Prj PairIdxLeft e) -- TLM: This is a sneaky hack because we know that the tag bits for Right -- and True are identical. @@ -73,14 +73,14 @@ isRight (Exp e) = Exp $ SmartExp $ (SmartExp $ Prj PairIdxLeft e) `Pair` SmartEx -- instead. -- fromLeft :: (Elt a, Elt b) => Exp (Either a b) -> Exp a -fromLeft (Exp e) = Exp $ SmartExp $ Prj PairIdxRight $ SmartExp $ Prj PairIdxLeft $ SmartExp $ Prj PairIdxRight e +fromLeft (Exp e) = mkExp $ Prj PairIdxRight $ SmartExp $ Prj PairIdxLeft $ SmartExp $ Prj PairIdxRight e -- | The 'fromRight' function extracts the element out of the 'Right' -- constructor. If the argument was actually 'Left', you will get an undefined -- value instead. -- fromRight :: (Elt a, Elt b) => Exp (Either a b) -> Exp b -fromRight (Exp e) = Exp $ SmartExp $ Prj PairIdxRight $ SmartExp $ Prj PairIdxRight e +fromRight (Exp e) = mkExp $ Prj PairIdxRight $ SmartExp $ Prj PairIdxRight e -- | The 'either' function performs case analysis on the 'Either' type. If the -- value is @'Left' a@, apply the first function to @a@; if it is @'Right' b@, @@ -88,8 +88,8 @@ fromRight (Exp e) = Exp $ SmartExp $ Prj PairIdxRight $ SmartExp $ Prj PairIdxRi -- either :: (Elt a, Elt b, Elt c) => (Exp a -> Exp c) -> (Exp b -> Exp c) -> Exp (Either a b) -> Exp c either f g = match \case - Left_ x -> f x - Right_ x -> g x + Left x -> f x + Right x -> g x -- | Extract from the array of 'Either' all of the 'Left' elements, together -- with a segment descriptor indicating how many elements along each dimension @@ -111,32 +111,33 @@ rights es = compact (map isRight es) (map fromRight es) instance Elt a => Functor (Either a) where - fmap f = either Left_ (Right_ . f) + fmap f = either Left (Right . f) instance Elt a => Monad (Either a) where - return = Right_ - x >>= f = either Left_ f x + return = Right + x >>= f = either Left f x instance (Eq a, Eq b) => Eq (Either a b) where (==) = match go where - go (Left_ x) (Left_ y) = x == y - go (Right_ x) (Right_ y) = x == y - go _ _ = False_ + go (Left x) (Left y) = x == y + go (Right x) (Right y) = x == y + go _ _ = False instance (Ord a, Ord b) => Ord (Either a b) where compare = match go where - go (Left_ x) (Left_ y) = compare x y - go (Right_ x) (Right_ y) = compare x y - go Left_{} Right_{} = LT_ - go Right_{} Left_{} = GT_ + go :: Exp (Either a b) -> Exp (Either a b) -> Exp Ordering + go (Left x) (Left y) = compare x y + go (Right x) (Right y) = compare x y + go Left{} Right{} = LT + go Right{} Left{} = GT instance (Elt a, Elt b) => Semigroup (Exp (Either a b)) where ex <> ey = isLeft ex ? ( ey, ex ) instance (Lift Exp a, Lift Exp b, Elt (Plain a), Elt (Plain b)) => Lift Exp (Either a b) where type Plain (Either a b) = Either (Plain a) (Plain b) - lift (Left a) = Left_ (lift a) - lift (Right b) = Right_ (lift b) + lift (Left a) = Left (lift a) + lift (Right b) = Right (lift b) diff --git a/src/Data/Array/Accelerate/Data/Maybe.hs b/src/Data/Array/Accelerate/Data/Maybe.hs index 305da9a71..b42b5c0fc 100644 --- a/src/Data/Array/Accelerate/Data/Maybe.hs +++ b/src/Data/Array/Accelerate/Data/Maybe.hs @@ -28,11 +28,12 @@ module Data.Array.Accelerate.Data.Maybe ( - Maybe(..), pattern Nothing_, pattern Just_, + Maybe, pattern Nothing, pattern Just, maybe, isJust, isNothing, fromMaybe, fromJust, justs, ) where +import Data.Array.Accelerate.AST ( PrimFun(..) ) import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.Language import Data.Array.Accelerate.Lift @@ -53,7 +54,6 @@ import Data.Array.Accelerate.Data.Monoid import Data.Array.Accelerate.Data.Semigroup import Data.Function ( (&) ) -import Data.Maybe ( Maybe(..) ) import Prelude ( ($), (.) ) @@ -65,7 +65,7 @@ isNothing = not . isJust -- | Returns 'True' if the argument is of the form @Just _@ -- isJust :: Elt a => Exp (Maybe a) -> Exp Bool -isJust (Exp x) = Exp $ SmartExp $ (SmartExp $ Prj PairIdxLeft x) `Pair` SmartExp Nil +isJust (Exp x) = mkExp $ PrimApp (PrimToBool integralType bitType) (SmartExp $ Prj PairIdxLeft x) -- TLM: This is a sneaky hack because we know that the tag bits for Just -- and True are identical. @@ -75,8 +75,8 @@ isJust (Exp x) = Exp $ SmartExp $ (SmartExp $ Prj PairIdxLeft x) `Pair` SmartExp -- fromMaybe :: Elt a => Exp a -> Exp (Maybe a) -> Exp a fromMaybe d = match \case - Nothing_ -> d - Just_ x -> x + Nothing -> d + Just x -> x -- | The 'fromJust' function extracts the element out of the 'Just' constructor. -- If the argument was actually 'Nothing', you will get an undefined value @@ -92,8 +92,8 @@ fromJust (Exp x) = Exp $ SmartExp (PairIdxRight `Prj` SmartExp (PairIdxRight `Pr -- maybe :: (Elt a, Elt b) => Exp b -> (Exp a -> Exp b) -> Exp (Maybe a) -> Exp b maybe d f = match \case - Nothing_ -> d - Just_ x -> f x + Nothing -> d + Just x -> f x -- | Extract from an array all of the 'Just' values, together with a segment -- descriptor indicating how many elements along each dimension were returned. @@ -106,40 +106,44 @@ justs xs = compact (map isJust xs) (map fromJust xs) instance Functor Maybe where fmap f = match \case - Nothing_ -> Nothing_ - Just_ x -> Just_ (f x) + Nothing -> Nothing + Just x -> Just (f x) instance Monad Maybe where - return = Just_ + return = Just mx >>= f = mx & match \case - Nothing_ -> Nothing_ - Just_ x -> f x + Nothing -> Nothing + Just x -> f x instance Eq a => Eq (Maybe a) where (==) = match go where - go Nothing_ Nothing_ = True_ - go (Just_ x) (Just_ y) = x == y - go _ _ = False_ + go Nothing Nothing = True + go (Just x) (Just y) = x == y + go _ _ = False instance Ord a => Ord (Maybe a) where compare = match go where - go (Just_ x) (Just_ y) = compare x y - go Nothing_ Nothing_ = EQ_ - go Nothing_ Just_{} = LT_ - go Just_{} Nothing_{} = GT_ + go :: Exp (Maybe a) -> Exp (Maybe a) -> Exp Ordering + go (Just x) (Just y) = compare x y + go Nothing Nothing = EQ + go Nothing Just{} = LT + go Just{} Nothing{} = GT instance (Monoid (Exp a), Elt a) => Monoid (Exp (Maybe a)) where - mempty = Nothing_ + mempty = Nothing instance (Semigroup (Exp a), Elt a) => Semigroup (Exp (Maybe a)) where - ma <> mb = cond (isNothing ma) mb - $ cond (isNothing mb) ma - $ lift (Just (fromJust ma <> fromJust mb)) + (<>) = match go + where + go :: Exp (Maybe a) -> Exp (Maybe a) -> Exp (Maybe a) + go Nothing b = b + go a Nothing = a + go (Just a) (Just b) = Just (a <> b) instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Maybe a) where type Plain (Maybe a) = Maybe (Plain a) - lift Nothing = Nothing_ - lift (Just a) = Just_ (lift a) + lift Nothing = Nothing + lift (Just a) = Just (lift a) diff --git a/src/Data/Array/Accelerate/Data/Monoid.hs b/src/Data/Array/Accelerate/Data/Monoid.hs index 23576be77..20a13aa76 100644 --- a/src/Data/Array/Accelerate/Data/Monoid.hs +++ b/src/Data/Array/Accelerate/Data/Monoid.hs @@ -1,16 +1,17 @@ -{-# LANGUAGE CPP #-} -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fno-warn-orphans #-} #if __GLASGOW_HASKELL__ >= 806 -{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE UndecidableInstances #-} #endif -- | -- Module : Data.Array.Accelerate.Data.Monoid @@ -30,8 +31,8 @@ module Data.Array.Accelerate.Data.Monoid ( Monoid(..), (<>), - Sum(..), pattern Sum_, - Product(..), pattern Product_, + Sum, pattern Sum, + Product, pattern Product, ) where @@ -43,35 +44,51 @@ import Data.Array.Accelerate.Data.Semigroup () import Data.Array.Accelerate.Language import Data.Array.Accelerate.Lift import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Type import Data.Function -import Data.Monoid hiding ( (<>) ) -import Data.Semigroup +import Data.Monoid ( Monoid(..), Product, Sum ) +import Data.Semigroup ( Semigroup(..) ) import qualified Prelude as P +import qualified Data.Monoid as P -- Sum: Monoid under addition -- -------------------------- -pattern Sum_ :: Elt a => Exp a -> Exp (Sum a) -pattern Sum_ x = Pattern x -{-# COMPLETE Sum_ #-} +pattern Sum :: IsSum a b => a -> b +pattern Sum x <- (matchSum -> x) + where Sum = buildSum +{-# COMPLETE Sum :: Sum #-} +{-# COMPLETE Sum :: Exp #-} + +class IsSum a b | b -> a where + matchSum :: b -> a + buildSum :: a -> b + +instance IsSum a (Sum a) where + matchSum = P.getSum + buildSum = P.Sum + +instance Elt a => IsSum (Exp a) (Exp (Sum a)) where + matchSum (Pattern x) = x + buildSum x = Pattern x instance Elt a => Elt (Sum a) instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Sum a) where type Plain (Sum a) = Sum (Plain a) - lift (Sum a) = Sum_ (lift a) + lift (Sum a) = Sum (lift a) instance Elt a => Unlift Exp (Sum (Exp a)) where - unlift (Sum_ a) = Sum a + unlift (Sum a) = Sum a instance Bounded a => P.Bounded (Exp (Sum a)) where - minBound = Sum_ minBound - maxBound = Sum_ maxBound + minBound = Sum minBound + maxBound = Sum maxBound instance Num a => P.Num (Exp (Sum a)) where (+) = lift2 ((+) :: Sum (Exp a) -> Sum (Exp a) -> Sum (Exp a)) @@ -83,45 +100,59 @@ instance Num a => P.Num (Exp (Sum a)) where fromInteger x = lift (P.fromInteger x :: Sum (Exp a)) instance Eq a => Eq (Sum a) where - (==) = lift2 ((==) `on` getSum) - (/=) = lift2 ((/=) `on` getSum) + (==) = lift2 ((==) `on` P.getSum) + (/=) = lift2 ((/=) `on` P.getSum) instance Ord a => Ord (Sum a) where - (<) = lift2 ((<) `on` getSum) - (>) = lift2 ((>) `on` getSum) - (<=) = lift2 ((<=) `on` getSum) - (>=) = lift2 ((>=) `on` getSum) - min x y = Sum_ $ lift2 (min `on` getSum) x y - max x y = Sum_ $ lift2 (max `on` getSum) x y + (<) = lift2 ((<) `on` P.getSum) + (>) = lift2 ((>) `on` P.getSum) + (<=) = lift2 ((<=) `on` P.getSum) + (>=) = lift2 ((>=) `on` P.getSum) + min x y = Sum $ lift2 (min `on` P.getSum) x y + max x y = Sum $ lift2 (max `on` P.getSum) x y instance Num a => Monoid (Exp (Sum a)) where mempty = 0 -- | @since 1.2.0.0 instance Num a => Semigroup (Exp (Sum a)) where - (<>) = (+) - stimes n (Sum_ x) = Sum_ $ P.fromIntegral n * x + (<>) = (+) + stimes n (Sum x) = Sum $ P.fromIntegral n * x -- Product: Monoid under multiplication -- ------------------------------------ -pattern Product_ :: Elt a => Exp a -> Exp (Product a) -pattern Product_ x = Pattern x -{-# COMPLETE Product_ #-} +pattern Product :: IsProduct a b => a -> b +pattern Product x <- (matchProduct -> x) + where Product = buildProduct +{-# COMPLETE Product :: Product #-} +{-# COMPLETE Product :: Exp #-} + +class IsProduct a b | b -> a where + matchProduct :: b -> a + buildProduct :: a -> b + +instance IsProduct a (Product a) where + matchProduct = P.getProduct + buildProduct = P.Product + +instance Elt a => IsProduct (Exp a) (Exp (Product a)) where + matchProduct (Pattern x) = x + buildProduct x = Pattern x instance Elt a => Elt (Product a) instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Product a) where type Plain (Product a) = Product (Plain a) - lift (Product a) = Product_ (lift a) + lift (Product a) = Product (lift a) instance Elt a => Unlift Exp (Product (Exp a)) where - unlift (Product_ a) = Product a + unlift (Product a) = Product a instance Bounded a => P.Bounded (Exp (Product a)) where - minBound = Product_ minBound - maxBound = Product_ maxBound + minBound = Product minBound + maxBound = Product maxBound instance Num a => P.Num (Exp (Product a)) where (+) = lift2 ((+) :: Product (Exp a) -> Product (Exp a) -> Product (Exp a)) @@ -133,24 +164,24 @@ instance Num a => P.Num (Exp (Product a)) where fromInteger x = lift (P.fromInteger x :: Product (Exp a)) instance Eq a => Eq (Product a) where - (==) = lift2 ((==) `on` getProduct) - (/=) = lift2 ((/=) `on` getProduct) + (==) = lift2 ((==) `on` P.getProduct) + (/=) = lift2 ((/=) `on` P.getProduct) instance Ord a => Ord (Product a) where - (<) = lift2 ((<) `on` getProduct) - (>) = lift2 ((>) `on` getProduct) - (<=) = lift2 ((<=) `on` getProduct) - (>=) = lift2 ((>=) `on` getProduct) - min x y = Product_ $ lift2 (min `on` getProduct) x y - max x y = Product_ $ lift2 (max `on` getProduct) x y + (<) = lift2 ((<) `on` P.getProduct) + (>) = lift2 ((>) `on` P.getProduct) + (<=) = lift2 ((<=) `on` P.getProduct) + (>=) = lift2 ((>=) `on` P.getProduct) + min x y = Product $ lift2 (min `on` P.getProduct) x y + max x y = Product $ lift2 (max `on` P.getProduct) x y instance Num a => Monoid (Exp (Product a)) where mempty = 1 -- | @since 1.2.0.0 instance Num a => Semigroup (Exp (Product a)) where - (<>) = (*) - stimes n (Product_ x) = Product_ $ x ^ (P.fromIntegral n :: Exp Int) + (<>) = (*) + stimes n (Product x) = Product $ x ^ (P.fromIntegral n :: Exp Int) -- Instances for unit and tuples diff --git a/src/Data/Array/Accelerate/Data/Ratio.hs b/src/Data/Array/Accelerate/Data/Ratio.hs index 190317297..22ce4af26 100644 --- a/src/Data/Array/Accelerate/Data/Ratio.hs +++ b/src/Data/Array/Accelerate/Data/Ratio.hs @@ -6,6 +6,7 @@ {-# LANGUAGE RebindableSyntax #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | @@ -31,10 +32,10 @@ module Data.Array.Accelerate.Data.Ratio ( import Data.Array.Accelerate.Language import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Prelude import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Enum import Data.Array.Accelerate.Classes.Eq @@ -75,7 +76,7 @@ reduce x y = -- infixl 7 % (%) :: Integral a => Exp a -> Exp a -> Exp (Ratio a) -x % y = reduce (x * signum y) (abs y) +x % y = reduce (x * signum y) (abs y) infinity :: Integral a => Exp (Ratio a) infinity = 1 :% 0 @@ -109,16 +110,17 @@ instance Integral a => P.Fractional (Exp (Ratio a)) where else y :% x fromRational r = fromInteger (P.numerator r) % fromInteger (P.denominator r) -instance (Integral a, FromIntegral a Int64) => RealFrac (Ratio a) where +instance Integral a => RealFrac (Ratio a) where + type Significand (Ratio a) = a properFraction (x :% y) = let (q,r) = quotRem x y - in (fromIntegral (fromIntegral q :: Exp Int64), r :% y) + in T2 (fromIntegral q) (r :% y) instance (Integral a, ToFloating a b) => ToFloating (Ratio a) b where toFloating (x :% y) = let x' :% y' = reduce x y - in toFloating x' / toFloating y' + in toFloating x' / toFloating y' instance (FromIntegral a b, Integral b) => FromIntegral a (Ratio b) where fromIntegral x = fromIntegral x :% 1 diff --git a/src/Data/Array/Accelerate/Data/Semigroup.hs b/src/Data/Array/Accelerate/Data/Semigroup.hs index 030c51243..82ecdebe5 100644 --- a/src/Data/Array/Accelerate/Data/Semigroup.hs +++ b/src/Data/Array/Accelerate/Data/Semigroup.hs @@ -1,17 +1,18 @@ -{-# LANGUAGE CPP #-} -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RebindableSyntax #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RebindableSyntax #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fno-warn-orphans #-} #if __GLASGOW_HASKELL__ >= 806 -{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE UndecidableInstances #-} #endif -- | -- Module : Data.Array.Accelerate.Data.Semigroup @@ -31,8 +32,8 @@ module Data.Array.Accelerate.Data.Semigroup ( Semigroup(..), - Min(..), pattern Min_, - Max(..), pattern Max_, + Min, pattern Min, + Max, pattern Max, ) where @@ -47,26 +48,41 @@ import Data.Array.Accelerate.Sugar.Elt import Data.Function import Data.Monoid ( Monoid(..) ) -import Data.Semigroup +import Data.Semigroup ( Semigroup(..), Min, Max ) import qualified Prelude as P +import qualified Data.Semigroup as P -pattern Min_ :: Elt a => Exp a -> Exp (Min a) -pattern Min_ x = Pattern x -{-# COMPLETE Min_ #-} +pattern Min :: IsMin a b => a -> b +pattern Min x <- (matchMin -> x) + where Min = buildMin +{-# COMPLETE Min :: Min #-} +{-# COMPLETE Min :: Exp #-} + +class IsMin a b | b -> a where + matchMin :: b -> a + buildMin :: a -> b + +instance IsMin a (Min a) where + matchMin = P.getMin + buildMin = P.Min + +instance Elt a => IsMin (Exp a) (Exp (Min a)) where + matchMin (Pattern x) = x + buildMin x = Pattern x instance Elt a => Elt (Min a) instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Min a) where type Plain (Min a) = Min (Plain a) - lift (Min a) = Min_ (lift a) + lift (Min a) = Min (lift a) instance Elt a => Unlift Exp (Min (Exp a)) where - unlift (Min_ a) = Min a + unlift (Min a) = Min a instance Bounded a => P.Bounded (Exp (Min a)) where - minBound = lift $ Min (minBound :: Exp a) - maxBound = lift $ Min (maxBound :: Exp a) + minBound = Min minBound + maxBound = Min maxBound instance Num a => P.Num (Exp (Min a)) where (+) = lift2 ((+) :: Min (Exp a) -> Min (Exp a) -> Min (Exp a)) @@ -78,42 +94,56 @@ instance Num a => P.Num (Exp (Min a)) where fromInteger x = lift (P.fromInteger x :: Min (Exp a)) instance Eq a => Eq (Min a) where - (==) = lift2 ((==) `on` getMin) - (/=) = lift2 ((/=) `on` getMin) + (==) = lift2 ((==) `on` P.getMin) + (/=) = lift2 ((/=) `on` P.getMin) instance Ord a => Ord (Min a) where - (<) = lift2 ((<) `on` getMin) - (>) = lift2 ((>) `on` getMin) - (<=) = lift2 ((<=) `on` getMin) - (>=) = lift2 ((>=) `on` getMin) - min x y = lift . Min $ lift2 (min `on` getMin) x y - max x y = lift . Min $ lift2 (max `on` getMin) x y + (<) = lift2 ((<) `on` P.getMin) + (>) = lift2 ((>) `on` P.getMin) + (<=) = lift2 ((<=) `on` P.getMin) + (>=) = lift2 ((>=) `on` P.getMin) + min x y = Min $ lift2 (min `on` P.getMin) x y + max x y = Min $ lift2 (max `on` P.getMin) x y instance Ord a => Semigroup (Exp (Min a)) where - x <> y = lift . Min $ lift2 (min `on` getMin) x y - stimes = stimesIdempotent + x <> y = Min $ lift2 (min `on` P.getMin) x y + stimes = P.stimesIdempotent instance (Ord a, Bounded a) => Monoid (Exp (Min a)) where mempty = maxBound mappend = (<>) -pattern Max_ :: Elt a => Exp a -> Exp (Max a) -pattern Max_ x = Pattern x -{-# COMPLETE Max_ #-} +pattern Max :: IsMax a b => a -> b +pattern Max x <- (matchMax -> x) + where Max = buildMax +{-# COMPLETE Max :: Max #-} +{-# COMPLETE Max :: Exp #-} + +class IsMax a b | b -> a where + matchMax :: b -> a + buildMax :: a -> b + +instance IsMax a (Max a) where + matchMax = P.getMax + buildMax = P.Max + +instance Elt a => IsMax (Exp a) (Exp (Max a)) where + matchMax (Pattern x) = x + buildMax x = Pattern x instance Elt a => Elt (Max a) instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Max a) where type Plain (Max a) = Max (Plain a) - lift (Max a) = Max_ (lift a) + lift (Max a) = Max (lift a) instance Elt a => Unlift Exp (Max (Exp a)) where - unlift (Max_ a) = Max a + unlift (Max a) = Max a instance Bounded a => P.Bounded (Exp (Max a)) where - minBound = Max_ minBound - maxBound = Max_ maxBound + minBound = Max minBound + maxBound = Max maxBound instance Num a => P.Num (Exp (Max a)) where (+) = lift2 ((+) :: Max (Exp a) -> Max (Exp a) -> Max (Exp a)) @@ -125,20 +155,20 @@ instance Num a => P.Num (Exp (Max a)) where fromInteger x = lift (P.fromInteger x :: Max (Exp a)) instance Eq a => Eq (Max a) where - (==) = lift2 ((==) `on` getMax) - (/=) = lift2 ((/=) `on` getMax) + (==) = lift2 ((==) `on` P.getMax) + (/=) = lift2 ((/=) `on` P.getMax) instance Ord a => Ord (Max a) where - (<) = lift2 ((<) `on` getMax) - (>) = lift2 ((>) `on` getMax) - (<=) = lift2 ((<=) `on` getMax) - (>=) = lift2 ((>=) `on` getMax) - min x y = Max_ $ lift2 (min `on` getMax) x y - max x y = Max_ $ lift2 (max `on` getMax) x y + (<) = lift2 ((<) `on` P.getMax) + (>) = lift2 ((>) `on` P.getMax) + (<=) = lift2 ((<=) `on` P.getMax) + (>=) = lift2 ((>=) `on` P.getMax) + min x y = Max $ lift2 (min `on` P.getMax) x y + max x y = Max $ lift2 (max `on` P.getMax) x y instance Ord a => Semigroup (Exp (Max a)) where - x <> y = Max_ $ lift2 (max `on` getMax) x y - stimes = stimesIdempotent + x <> y = Max $ lift2 (max `on` P.getMax) x y + stimes = P.stimesIdempotent instance (Ord a, Bounded a) => Monoid (Exp (Max a)) where mempty = minBound diff --git a/src/Data/Array/Accelerate/Error.hs b/src/Data/Array/Accelerate/Error.hs index 3f00a6b5c..de6741a96 100644 --- a/src/Data/Array/Accelerate/Error.hs +++ b/src/Data/Array/Accelerate/Error.hs @@ -58,7 +58,7 @@ unsafeCheck = withFrozenCallStack $ check Unsafe -- | Throw an error if the index is not in range, otherwise evaluate the result. -- -indexCheck :: HasCallStack => Int -> Int -> a -> a +indexCheck :: (HasCallStack, Integral i) => i -> i -> a -> a indexCheck i n = boundsCheck (bformat ("index out of bounds: i=" % int % ", n=" % int) i n) (i >= 0 && i < n) diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index 5b8e6401a..0e3a2e168 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -1,8 +1,12 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ParallelListComp #-} {-# LANGUAGE PatternGuards #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RecordWildCards #-} @@ -37,14 +41,16 @@ module Data.Array.Accelerate.Interpreter ( run, run1, runN, -- Internal (hidden) - evalPrim, evalPrimConst, evalCoerceScalar, atraceOp, + evalPrim, evalBitcastScalar, atraceOp, acoerceOp, ) where import Data.Array.Accelerate.AST hiding ( Boundary(..) ) import Data.Array.Accelerate.AST.Environment import Data.Array.Accelerate.AST.Var +import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Array.Data +import Data.Array.Accelerate.Array.Unique import Data.Array.Accelerate.Error import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Elt @@ -53,27 +59,27 @@ import Data.Array.Accelerate.Representation.Slice import Data.Array.Accelerate.Representation.Stencil import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type -import Data.Array.Accelerate.Representation.Vec import Data.Array.Accelerate.Trafo import Data.Array.Accelerate.Trafo.Delayed ( DelayedOpenAfun, DelayedOpenAcc ) import Data.Array.Accelerate.Trafo.Sharing ( AfunctionR, AfunctionRepr(..), afunctionRepr ) import Data.Array.Accelerate.Type -import Data.Primitive.Vec +import Data.Primitive.Bit as Bit +import Data.Primitive.Vec as Vec +import Data.Primitive.Vec as Prim import qualified Data.Array.Accelerate.AST as AST import qualified Data.Array.Accelerate.Debug.Internal.Flags as Debug import qualified Data.Array.Accelerate.Debug.Internal.Graph as Debug import qualified Data.Array.Accelerate.Debug.Internal.Stats as Debug import qualified Data.Array.Accelerate.Debug.Internal.Timed as Debug +import qualified Data.Array.Accelerate.Interpreter.Arithmetic as A import qualified Data.Array.Accelerate.Smart as Smart import qualified Data.Array.Accelerate.Sugar.Array as Sugar -import qualified Data.Array.Accelerate.Sugar.Elt as Sugar import qualified Data.Array.Accelerate.Trafo.Delayed as AST import Control.DeepSeq import Control.Exception import Control.Monad import Control.Monad.ST -import Data.Bits import Data.Primitive.ByteArray import Data.Primitive.Types import Data.Text.Lazy.Builder @@ -81,8 +87,11 @@ import Formatting import System.IO import System.IO.Unsafe ( unsafePerformIO ) import Unsafe.Coerce -import qualified Data.Text.IO as T import Prelude hiding ( (!!), sum ) +import qualified Data.Text.IO as T + +import GHC.Prim +import GHC.TypeLits -- Program execution @@ -125,13 +134,6 @@ runN f = go eval AfunctionReprBody (Abody b) aenv = unsafePerformIO $ phase "execute" Debug.elapsed (Sugar.toArr . snd <$> evaluate (evalOpenAcc b aenv)) eval _ _aenv _ = error "Two men say they're Jesus; one of them must be wrong" --- -- | Stream a lazily read list of input arrays through the given program, --- -- collecting results as we go --- -- --- streamOut :: Arrays a => Sugar.Seq [a] -> [a] --- streamOut seq = let seq' = convertSeqWith config seq --- in evalDelayedSeq defaultSeqConfig seq' - -- Debugging -- --------- @@ -151,7 +153,7 @@ data Delayed a where Delayed :: ArrayR (Array sh e) -> sh -> (sh -> e) - -> (Int -> e) + -> (INT -> e) -> Delayed (Array sh e) @@ -213,6 +215,8 @@ evalOpenAcc (AST.Manifest pacc) aenv = (TupRpair r1 r2, (a1, a2)) Anil -> (TupRunit, ()) Atrace msg as bs -> unsafePerformIO $ manifest bs <$ atraceOp msg (snd $ manifest as) + Acoerce scale bR acc -> let (TupRsingle (ArrayR shR aR), as) = manifest acc + in (TupRsingle (ArrayR shR bR), acoerceOp scale aR bR as) Apply repr afun acc -> (repr, evalOpenAfun afun aenv $ snd $ manifest acc) Aforeign repr _ afun acc -> (repr, evalOpenAfun afun Empty $ snd $ manifest acc) Acond p acc1 acc2 @@ -225,8 +229,8 @@ evalOpenAcc (AST.Manifest pacc) aenv = p = evalOpenAfun cond aenv f = evalOpenAfun body aenv go !x - | toBool (linearIndexArray (Sugar.eltR @Word8) (p x) 0) = go (f x) - | otherwise = x + | toBool (linearIndexArray (TupRsingle (BitScalarType TypeBit)) (p x) 0) = go (f x) + | otherwise = x Use repr arr -> (TupRsingle repr, arr) Unit tp e -> unitOp tp (evalE e) @@ -371,7 +375,7 @@ zipWithOp tp f (Delayed (ArrayR shr _) shx xs _) (Delayed _ shy ys _) foldOp :: (e -> e -> e) -> e - -> Delayed (Array (sh, Int) e) + -> Delayed (Array (sh, INT) e) -> WithReprs (Array sh e) foldOp f z (Delayed (ArrayR (ShapeRsnoc shr) tp) (sh, n) arr _) = fromFunction' (ArrayR shr tp) sh (\ix -> iter (ShapeRsnoc ShapeRz) ((), n) (\((), i) -> arr (ix, i)) f z) @@ -380,7 +384,7 @@ foldOp f z (Delayed (ArrayR (ShapeRsnoc shr) tp) (sh, n) arr _) fold1Op :: HasCallStack => (e -> e -> e) - -> Delayed (Array (sh, Int) e) + -> Delayed (Array (sh, INT) e) -> WithReprs (Array sh e) fold1Op f (Delayed (ArrayR (ShapeRsnoc shr) tp) (sh, n) arr _) = boundsCheck "empty array" (n > 0) @@ -389,12 +393,12 @@ fold1Op f (Delayed (ArrayR (ShapeRsnoc shr) tp) (sh, n) arr _) foldSegOp :: HasCallStack - => IntegralType i + => SingleIntegralType i -> (e -> e -> e) -> e - -> Delayed (Array (sh, Int) e) + -> Delayed (Array (sh, INT) e) -> Delayed (Segments i) - -> WithReprs (Array (sh, Int) e) + -> WithReprs (Array (sh, INT) e) foldSegOp itp f z (Delayed repr (sh, _) arr _) (Delayed _ ((), n) _ seg) | IntegralDict <- integralDict itp = boundsCheck "empty segment descriptor" (n > 0) @@ -408,11 +412,11 @@ foldSegOp itp f z (Delayed repr (sh, _) arr _) (Delayed _ ((), n) _ seg) fold1SegOp :: HasCallStack - => IntegralType i + => SingleIntegralType i -> (e -> e -> e) - -> Delayed (Array (sh, Int) e) + -> Delayed (Array (sh, INT) e) -> Delayed (Segments i) - -> WithReprs (Array (sh, Int) e) + -> WithReprs (Array (sh, INT) e) fold1SegOp itp f (Delayed repr (sh, _) arr _) (Delayed _ ((), n) _ seg) | IntegralDict <- integralDict itp = boundsCheck "empty segment descriptor" (n > 0) @@ -427,8 +431,8 @@ fold1SegOp itp f (Delayed repr (sh, _) arr _) (Delayed _ ((), n) _ seg) scanl1Op :: forall sh e. HasCallStack => (e -> e -> e) - -> Delayed (Array (sh, Int) e) - -> WithReprs (Array (sh, Int) e) + -> Delayed (Array (sh, INT) e) + -> WithReprs (Array (sh, INT) e) scanl1Op f (Delayed (ArrayR shr tp) sh ain _) = ( TupRsingle $ ArrayR shr tp , adata `seq` Array sh adata @@ -436,13 +440,13 @@ scanl1Op f (Delayed (ArrayR shr tp) sh ain _) where -- (adata, _) = runArrayData @e $ do - aout <- newArrayData tp (size shr sh) + aout <- newArrayData tp (fromIntegral $ size shr sh) - let write (sz, 0) = writeArrayData tp aout (toIndex shr sh (sz, 0)) (ain (sz, 0)) + let write (sz, 0) = writeArrayData tp aout (fromIntegral $ toIndex shr sh (sz, 0)) (ain (sz, 0)) write (sz, i) = do - x <- readArrayData tp aout (toIndex shr sh (sz, i-1)) + x <- readArrayData tp aout (fromIntegral $ toIndex shr sh (sz, i-1)) let y = ain (sz, i) - writeArrayData tp aout (toIndex shr sh (sz, i)) (f x y) + writeArrayData tp aout (fromIntegral $ toIndex shr sh (sz, i)) (f x y) iter shr sh write (>>) (return ()) return (aout, undefined) @@ -452,8 +456,8 @@ scanlOp :: forall sh e. (e -> e -> e) -> e - -> Delayed (Array (sh, Int) e) - -> WithReprs (Array (sh, Int) e) + -> Delayed (Array (sh, INT) e) + -> WithReprs (Array (sh, INT) e) scanlOp f z (Delayed (ArrayR shr tp) (sh, n) ain _) = ( TupRsingle $ ArrayR shr tp , adata `seq` Array sh' adata @@ -462,13 +466,13 @@ scanlOp f z (Delayed (ArrayR shr tp) (sh, n) ain _) sh' = (sh, n+1) -- (adata, _) = runArrayData @e $ do - aout <- newArrayData tp (size shr sh') + aout <- newArrayData tp (fromIntegral $ size shr sh') - let write (sz, 0) = writeArrayData tp aout (toIndex shr sh' (sz, 0)) z + let write (sz, 0) = writeArrayData tp aout (fromIntegral $ toIndex shr sh' (sz, 0)) z write (sz, i) = do - x <- readArrayData tp aout (toIndex shr sh' (sz, i-1)) + x <- readArrayData tp aout (fromIntegral $ toIndex shr sh' (sz, i-1)) let y = ain (sz, i-1) - writeArrayData tp aout (toIndex shr sh' (sz, i)) (f x y) + writeArrayData tp aout (fromIntegral $ toIndex shr sh' (sz, i)) (f x y) iter shr sh' write (>>) (return ()) return (aout, undefined) @@ -478,26 +482,26 @@ scanl'Op :: forall sh e. (e -> e -> e) -> e - -> Delayed (Array (sh, Int) e) - -> WithReprs (Array (sh, Int) e, Array sh e) + -> Delayed (Array (sh, INT) e) + -> WithReprs (Array (sh, INT) e, Array sh e) scanl'Op f z (Delayed (ArrayR shr@(ShapeRsnoc shr') tp) (sh, n) ain _) = ( TupRsingle (ArrayR shr tp) `TupRpair` TupRsingle (ArrayR shr' tp) , aout `seq` asum `seq` ( Array (sh, n) aout, Array sh asum ) ) where ((aout, asum), _) = runArrayData @(e, e) $ do - aout <- newArrayData tp (size shr (sh, n)) - asum <- newArrayData tp (size shr' sh) + aout <- newArrayData tp (fromIntegral $ size shr (sh, n)) + asum <- newArrayData tp (fromIntegral $ size shr' sh) let write (sz, 0) - | n == 0 = writeArrayData tp asum (toIndex shr' sh sz) z - | otherwise = writeArrayData tp aout (toIndex shr (sh, n) (sz, 0)) z + | n == 0 = writeArrayData tp asum (fromIntegral $ toIndex shr' sh sz) z + | otherwise = writeArrayData tp aout (fromIntegral $ toIndex shr (sh, n) (sz, 0)) z write (sz, i) = do - x <- readArrayData tp aout (toIndex shr (sh, n) (sz, i-1)) + x <- readArrayData tp aout (fromIntegral $ toIndex shr (sh, n) (sz, i-1)) let y = ain (sz, i-1) if i == n - then writeArrayData tp asum (toIndex shr' sh sz) (f x y) - else writeArrayData tp aout (toIndex shr (sh, n) (sz, i)) (f x y) + then writeArrayData tp asum (fromIntegral $ toIndex shr' sh sz) (f x y) + else writeArrayData tp aout (fromIntegral $ toIndex shr (sh, n) (sz, i)) (f x y) iter shr (sh, n+1) write (>>) (return ()) return ((aout, asum), undefined) @@ -507,8 +511,8 @@ scanrOp :: forall sh e. (e -> e -> e) -> e - -> Delayed (Array (sh, Int) e) - -> WithReprs (Array (sh, Int) e) + -> Delayed (Array (sh, INT) e) + -> WithReprs (Array (sh, INT) e) scanrOp f z (Delayed (ArrayR shr tp) (sz, n) ain _) = ( TupRsingle (ArrayR shr tp) , adata `seq` Array sh' adata @@ -517,13 +521,13 @@ scanrOp f z (Delayed (ArrayR shr tp) (sz, n) ain _) sh' = (sz, n+1) -- (adata, _) = runArrayData @e $ do - aout <- newArrayData tp (size shr sh') + aout <- newArrayData tp (fromIntegral $ size shr sh') - let write (sz, 0) = writeArrayData tp aout (toIndex shr sh' (sz, n)) z + let write (sz, 0) = writeArrayData tp aout (fromIntegral $ toIndex shr sh' (sz, n)) z write (sz, i) = do let x = ain (sz, n-i) - y <- readArrayData tp aout (toIndex shr sh' (sz, n-i+1)) - writeArrayData tp aout (toIndex shr sh' (sz, n-i)) (f x y) + y <- readArrayData tp aout (fromIntegral $ toIndex shr sh' (sz, n-i+1)) + writeArrayData tp aout (fromIntegral $ toIndex shr sh' (sz, n-i)) (f x y) iter shr sh' write (>>) (return ()) return (aout, undefined) @@ -532,21 +536,21 @@ scanrOp f z (Delayed (ArrayR shr tp) (sz, n) ain _) scanr1Op :: forall sh e. HasCallStack => (e -> e -> e) - -> Delayed (Array (sh, Int) e) - -> WithReprs (Array (sh, Int) e) + -> Delayed (Array (sh, INT) e) + -> WithReprs (Array (sh, INT) e) scanr1Op f (Delayed (ArrayR shr tp) sh@(_, n) ain _) = ( TupRsingle $ ArrayR shr tp , adata `seq` Array sh adata ) where (adata, _) = runArrayData @e $ do - aout <- newArrayData tp (size shr sh) + aout <- newArrayData tp (fromIntegral $ size shr sh) - let write (sz, 0) = writeArrayData tp aout (toIndex shr sh (sz, n-1)) (ain (sz, n-1)) + let write (sz, 0) = writeArrayData tp aout (fromIntegral $ toIndex shr sh (sz, n-1)) (ain (sz, n-1)) write (sz, i) = do let x = ain (sz, n-i-1) - y <- readArrayData tp aout (toIndex shr sh (sz, n-i)) - writeArrayData tp aout (toIndex shr sh (sz, n-i-1)) (f x y) + y <- readArrayData tp aout (fromIntegral $ toIndex shr sh (sz, n-i)) + writeArrayData tp aout (fromIntegral $ toIndex shr sh (sz, n-i-1)) (f x y) iter shr sh write (>>) (return ()) return (aout, undefined) @@ -556,27 +560,27 @@ scanr'Op :: forall sh e. (e -> e -> e) -> e - -> Delayed (Array (sh, Int) e) - -> WithReprs (Array (sh, Int) e, Array sh e) + -> Delayed (Array (sh, INT) e) + -> WithReprs (Array (sh, INT) e, Array sh e) scanr'Op f z (Delayed (ArrayR shr@(ShapeRsnoc shr') tp) (sh, n) ain _) = ( TupRsingle (ArrayR shr tp) `TupRpair` TupRsingle (ArrayR shr' tp) , aout `seq` asum `seq` ( Array (sh, n) aout, Array sh asum ) ) where ((aout, asum), _) = runArrayData @(e, e) $ do - aout <- newArrayData tp (size shr (sh, n)) - asum <- newArrayData tp (size shr' sh) + aout <- newArrayData tp (fromIntegral $ size shr (sh, n)) + asum <- newArrayData tp (fromIntegral $ size shr' sh) let write (sz, 0) - | n == 0 = writeArrayData tp asum (toIndex shr' sh sz) z - | otherwise = writeArrayData tp aout (toIndex shr (sh, n) (sz, n-1)) z + | n == 0 = writeArrayData tp asum (fromIntegral $ toIndex shr' sh sz) z + | otherwise = writeArrayData tp aout (fromIntegral $ toIndex shr (sh, n) (sz, n-1)) z write (sz, i) = do let x = ain (sz, n-i) - y <- readArrayData tp aout (toIndex shr (sh, n) (sz, n-i)) + y <- readArrayData tp aout (fromIntegral $ toIndex shr (sh, n) (sz, n-i)) if i == n - then writeArrayData tp asum (toIndex shr' sh sz) (f x y) - else writeArrayData tp aout (toIndex shr (sh, n) (sz, n-i-1)) (f x y) + then writeArrayData tp asum (fromIntegral $ toIndex shr' sh sz) (f x y) + else writeArrayData tp aout (fromIntegral $ toIndex shr (sh, n) (sz, n-i-1)) (f x y) iter shr (sh, n+1) write (>>) (return ()) return ((aout, asum), undefined) @@ -596,14 +600,14 @@ permuteOp f (TupRsingle (ArrayR shr' _), def@(Array _ adef)) p (Delayed (ArrayR n' = size shr' sh' -- (adata, _) = runArrayData @e $ do - aout <- newArrayData tp n' + aout <- newArrayData tp (fromIntegral n') let -- initialise array with default values init i | i >= n' = return () | otherwise = do - x <- readArrayData tp adef i - writeArrayData tp aout i x + x <- readArrayData tp adef (fromIntegral i) + writeArrayData tp aout (fromIntegral i) x init (i+1) -- project each element onto the destination array and update @@ -615,8 +619,8 @@ permuteOp f (TupRsingle (ArrayR shr' _), def@(Array _ adef)) p (Delayed (ArrayR j = toIndex shr' sh' dst x = ain i -- - y <- readArrayData tp aout j - writeArrayData tp aout j (f x y) + y <- readArrayData tp aout (fromIntegral j) + writeArrayData tp aout (fromIntegral j) (f x y) _ -> internalError "unexpected tag" init 0 @@ -782,13 +786,13 @@ stencilAccess stencil = goR (stencilShapeR stencil) stencil -- Add a left-most component to an index -- - cons :: ShapeR sh -> Int -> sh -> (sh, Int) + cons :: ShapeR sh -> INT -> sh -> (sh, INT) cons ShapeRz ix () = ((), ix) cons (ShapeRsnoc shr) ix (sh, sz) = (cons shr ix sh, sz) -- Remove the left-most index of an index, and return the remainder -- - uncons :: ShapeR sh -> (sh, Int) -> (Int, sh) + uncons :: ShapeR sh -> (sh, INT) -> (INT, sh) uncons ShapeRz ((), v) = (v, ()) uncons (ShapeRsnoc shr) (v1, v2) = let (i, v1') = uncons shr v1 in (i, (v1', v2)) @@ -838,17 +842,6 @@ bounded shr bnd (Delayed _ sh f _) ix = _ -> internalError "unexpected boundary condition" | otherwise = iz --- toSeqOp :: forall slix sl dim co e proxy. (Elt slix, Shape sl, Shape dim, Elt e) --- => SliceIndex (EltRepr slix) --- (EltRepr sl) --- co --- (EltRepr dim) --- -> proxy slix --- -> Array dim e --- -> [Array sl e] --- toSeqOp sliceIndex _ arr = map (sliceOp sliceIndex arr :: slix -> Array sl e) --- (enumSlices sliceIndex (shape arr)) - -- Stencil boundary conditions -- --------------------------- @@ -932,17 +925,18 @@ evalOpenExp pexp env aenv = in evalOpenExp exp2 env' aenv Evar (Var _ ix) -> prj ix env Const _ c -> c - Undef tp -> undefElt (TupRsingle tp) - PrimConst c -> evalPrimConst c + Undef eR -> undefElt (TupRsingle eR) PrimApp f x -> evalPrim f (evalE x) Nil -> () Pair e1 e2 -> let !x1 = evalE e1 !x2 = evalE e2 in (x1, x2) - VecPack vecR e -> pack vecR $! evalE e - VecUnpack vecR e -> unpack vecR $! evalE e - IndexSlice slice slix sh -> restrict slice (evalE slix) - (evalE sh) + Extract vR iR v i -> evalExtract vR iR (evalE v) (evalE i) + Insert vR iR v i x -> evalInsert vR iR (evalE v) (evalE i) (evalE x) + Shuffle rR iR x y i -> let TupRsingle eR = expType x + in evalShuffle eR rR iR (evalE x) (evalE y) (evalE i) + Select eR m x y -> evalSelect eR (evalE m) (evalE x) (evalE y) + IndexSlice slice slix sh -> restrict slice (evalE slix) (evalE sh) where restrict :: SliceIndex slix sl co sh -> slix -> sh -> sl restrict SliceNil () () = () @@ -952,8 +946,7 @@ evalOpenExp pexp env aenv = restrict (SliceFixed sliceIdx) (slx, _i) (sl, _sz) = restrict sliceIdx slx sl - IndexFull slice slix sh -> extend slice (evalE slix) - (evalE sh) + IndexFull slice slix sh -> extend slice (evalE slix) (evalE sh) where extend :: SliceIndex slix sl co sh -> slix -> sl -> sh extend SliceNil () () = () @@ -966,11 +959,12 @@ evalOpenExp pexp env aenv = ToIndex shr sh ix -> toIndex shr (evalE sh) (evalE ix) FromIndex shr sh ix -> fromIndex shr (evalE sh) (evalE ix) - Case e rhs def -> evalE (caseof (evalE e) rhs) + Case tagR e rhs def -> evalE (caseof tagR (evalE e) rhs) where - caseof :: TAG -> [(TAG, OpenExp env aenv t)] -> OpenExp env aenv t - caseof tag = go + caseof :: forall tag. TagType tag -> tag -> [(tag, OpenExp env aenv t)] -> OpenExp env aenv t + caseof tagR tag | TagDict <- tagDict tagR = go where + go :: Eq tag => [(tag, OpenExp env aenv t)] -> OpenExp env aenv t go ((t,c):cs) | tag == t = c | otherwise = go cs @@ -998,16 +992,339 @@ evalOpenExp pexp env aenv = Shape acc -> shape $ snd $ evalA acc ShapeSize shr sh -> size shr (evalE sh) Foreign _ _ f e -> evalOpenFun f Empty Empty $ evalE e - Coerce t1 t2 e -> evalCoerceScalar t1 t2 (evalE e) + Bitcast t1 t2 e -> evalBitcastScalar t1 t2 (evalE e) -- Coercions -- --------- +acoerceOp + :: HasCallStack + => RescaleFactor + -> TypeR a + -> TypeR b + -> Array (sh, INT) a + -> Array (sh, INT) b +acoerceOp scale aR bR (Array (sz,sh) adata) = arr' + where + arr' = Array (sz,sh') adata' + sh' = case compare scale 0 of + EQ -> sh + GT -> sh * scale + LT -> let (q,r) = quotRem sh (negate scale) + in boundsCheck "shape mismatch" (r == 0) q + adata' = acoerce aR bR adata + + acoerce :: TypeR a -> TypeR b -> ArrayData a -> ArrayData b + acoerce TupRunit TupRunit () = () + acoerce (TupRpair TupRunit aR) bR ((), ad) | Just Refl <- matchTypeR aR bR = ad + acoerce (TupRpair aR TupRunit) bR (ad, ()) | Just Refl <- matchTypeR aR bR = ad + acoerce aR (TupRpair TupRunit bR) ad | Just Refl <- matchTypeR aR bR = ((), ad) + acoerce aR (TupRpair bR TupRunit) ad | Just Refl <- matchTypeR aR bR = (ad, ()) + acoerce (TupRpair aR1 aR2) (TupRpair bR1 bR2) (a1, a2) = (acoerce aR1 bR1 a1, acoerce aR2 bR2 a2) + acoerce (TupRsingle aR) (TupRsingle bR) ad = scalar aR bR ad + acoerce _ _ _ = internalError "missing cases for class Acoerce" + + scalar :: ScalarType a -> ScalarType b -> ArrayData a -> ArrayData b + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + num :: NumType a -> ScalarType b -> ArrayData a -> ArrayData b + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + bit :: forall a b. BitType a -> ScalarType b -> ArrayData a -> ArrayData b + bit TypeBit = scalar' @a + bit TypeMask{} = scalar' @a + + integral :: forall a b. IntegralType a -> ScalarType b -> ArrayData a -> ArrayData b + integral = \case + SingleIntegralType t -> single t + VectorIntegralType n t -> vector n t + where + single :: SingleIntegralType a -> ScalarType b -> ArrayData a -> ArrayData b + single TypeInt8 = scalar' @a + single TypeInt16 = scalar' @a + single TypeInt32 = scalar' @a + single TypeInt64 = scalar' @a + single TypeInt128 = scalar' @a + single TypeWord8 = scalar' @a + single TypeWord16 = scalar' @a + single TypeWord32 = scalar' @a + single TypeWord64 = scalar' @a + single TypeWord128 = scalar' @a + + vector :: (KnownNat n, a ~ Vec n c) => Proxy# n -> SingleIntegralType c -> ScalarType b -> ArrayData (Vec n c) -> ArrayData b + vector _ TypeInt8 = scalar' @a + vector _ TypeInt16 = scalar' @a + vector _ TypeInt32 = scalar' @a + vector _ TypeInt64 = scalar' @a + vector _ TypeInt128 = scalar' @a + vector _ TypeWord8 = scalar' @a + vector _ TypeWord16 = scalar' @a + vector _ TypeWord32 = scalar' @a + vector _ TypeWord64 = scalar' @a + vector _ TypeWord128 = scalar' @a + + floating :: forall a b. FloatingType a -> ScalarType b -> ArrayData a -> ArrayData b + floating = \case + SingleFloatingType t -> single t + VectorFloatingType n t -> vector n t + where + single :: SingleFloatingType a -> ScalarType b -> ArrayData a -> ArrayData b + single TypeFloat16 = scalar' @a + single TypeFloat32 = scalar' @a + single TypeFloat64 = scalar' @a + single TypeFloat128 = scalar' @a + + vector :: (KnownNat n, a ~ Vec n c) => Proxy# n -> SingleFloatingType c -> ScalarType b -> ArrayData (Vec n c) -> ArrayData b + vector _ TypeFloat16 = scalar' @a + vector _ TypeFloat32 = scalar' @a + vector _ TypeFloat64 = scalar' @a + vector _ TypeFloat128 = scalar' @a + + scalar' :: forall a b. ScalarType b -> ScalarArrayData a -> ArrayData b + scalar' (NumScalarType t) = num' @a t + scalar' (BitScalarType t) = bit' @a t + + num' :: forall a b. NumType b -> ScalarArrayData a -> ArrayData b + num' (IntegralNumType t) = integral' @a t + num' (FloatingNumType t) = floating' @a t + + bit' :: forall a b. BitType b -> ScalarArrayData a -> ArrayData b + bit' TypeBit = castUniqueArray + bit' TypeMask{} = castUniqueArray + + integral' :: forall a b. IntegralType b -> ScalarArrayData a -> ArrayData b + integral' = \case + SingleIntegralType t -> single' t + VectorIntegralType n t -> vector' n t + where + single' :: SingleIntegralType b -> ScalarArrayData a -> ArrayData b + single' TypeInt8 = castUniqueArray + single' TypeInt16 = castUniqueArray + single' TypeInt32 = castUniqueArray + single' TypeInt64 = castUniqueArray + single' TypeInt128 = castUniqueArray + single' TypeWord8 = castUniqueArray + single' TypeWord16 = castUniqueArray + single' TypeWord32 = castUniqueArray + single' TypeWord64 = castUniqueArray + single' TypeWord128 = castUniqueArray + + vector' :: KnownNat n => Proxy# n -> SingleIntegralType c -> ScalarArrayData a -> ArrayData (Vec n c) + vector' _ TypeInt8 = castUniqueArray + vector' _ TypeInt16 = castUniqueArray + vector' _ TypeInt32 = castUniqueArray + vector' _ TypeInt64 = castUniqueArray + vector' _ TypeInt128 = castUniqueArray + vector' _ TypeWord8 = castUniqueArray + vector' _ TypeWord16 = castUniqueArray + vector' _ TypeWord32 = castUniqueArray + vector' _ TypeWord64 = castUniqueArray + vector' _ TypeWord128 = castUniqueArray + + floating' :: forall a b. FloatingType b -> ScalarArrayData a -> ArrayData b + floating' = \case + SingleFloatingType t -> single' t + VectorFloatingType n t -> vector' n t + where + single' :: SingleFloatingType b -> ScalarArrayData a -> ArrayData b + single' TypeFloat16 = castUniqueArray + single' TypeFloat32 = castUniqueArray + single' TypeFloat64 = castUniqueArray + single' TypeFloat128 = castUniqueArray + + vector' :: KnownNat n => Proxy# n -> SingleFloatingType c -> ScalarArrayData a -> ArrayData (Vec n c) + vector' _ TypeFloat16 = castUniqueArray + vector' _ TypeFloat32 = castUniqueArray + vector' _ TypeFloat64 = castUniqueArray + vector' _ TypeFloat128 = castUniqueArray + + -- Coercion between two scalar types. We require that the size of the source and -- destination values are equal (this is not checked at this point). -- -evalCoerceScalar :: ScalarType a -> ScalarType b -> a -> b +evalBitcastScalar :: ScalarType a -> ScalarType b -> a -> b +evalBitcastScalar = scalar + where + scalar :: ScalarType a -> ScalarType b -> a -> b + scalar (NumScalarType a) = num a + scalar (BitScalarType a) = bit a + + bit :: BitType a -> ScalarType b -> a -> b + bit TypeBit = \case + BitScalarType TypeBit -> id + _ -> internalError "evalBitcastScalar @Bit" + bit (TypeMask _) = \case + NumScalarType b -> num' b + BitScalarType b -> bit' b + where + bit' :: BitType b -> Vec n Bit -> b + bit' TypeMask{} = unsafeCoerce + bit' TypeBit = internalError "evalBitcastScalar @Bit" + + num' :: NumType b -> Vec n Bit -> b + num' (IntegralNumType b) = integral' b + num' (FloatingNumType b) = floating' b + + integral' :: IntegralType b -> Vec n Bit -> b + integral' (VectorIntegralType _ _) = unsafeCoerce + integral' (SingleIntegralType b) + | IntegralDict <- integralDict b + = peek + + floating' :: FloatingType b -> Vec n Bit -> b + floating' (VectorFloatingType _ _) = unsafeCoerce + floating' (SingleFloatingType b) + | FloatingDict <- floatingDict b + = peek + + num :: NumType a -> ScalarType b -> a -> b + num (IntegralNumType a) = integral a + num (FloatingNumType t) = floating t + + integral :: IntegralType a -> ScalarType b -> a -> b + integral (SingleIntegralType a) = \case + NumScalarType b -> num' b a + BitScalarType b -> bit' b a + where + bit' :: BitType b -> SingleIntegralType a -> a -> b + bit' TypeBit _ = unsafeCoerce + bit' TypeMask{} a + | IntegralDict <- integralDict a + = poke + + num' :: NumType b -> SingleIntegralType a -> a -> b + num' (IntegralNumType b) = integral' b + num' (FloatingNumType b) = floating' b + + integral' :: IntegralType b -> SingleIntegralType a -> a -> b + integral' (SingleIntegralType _) _ = unsafeCoerce + integral' (VectorIntegralType _ b) a + | IntegralDict <- integralDict a + = case b of + TypeInt8 -> poke + TypeInt16 -> poke + TypeInt32 -> poke + TypeInt64 -> poke + TypeInt128 -> poke + TypeWord8 -> poke + TypeWord16 -> poke + TypeWord32 -> poke + TypeWord64 -> poke + TypeWord128 -> poke + + floating' :: FloatingType b -> SingleIntegralType a -> a -> b + floating' (SingleFloatingType _) _ = unsafeCoerce + floating' (VectorFloatingType _ b) a + | IntegralDict <- integralDict a + = case b of + TypeFloat16 -> poke + TypeFloat32 -> poke + TypeFloat64 -> poke + TypeFloat128 -> poke + + integral (VectorIntegralType _ a) = \case + NumScalarType b -> num' b a + BitScalarType b -> bit' b a + where + bit' :: BitType b -> SingleIntegralType a -> Vec n a -> b + bit' TypeBit _ = unsafeCoerce + bit' TypeMask{} _ = unsafeCoerce + + num' :: NumType b -> SingleIntegralType a -> Vec n a -> b + num' (IntegralNumType b) = integral' b + num' (FloatingNumType b) = floating' b + + integral' :: IntegralType b -> SingleIntegralType a -> Vec n a -> b + integral' (VectorIntegralType _ _) _ = unsafeCoerce + integral' (SingleIntegralType b) _ + | IntegralDict <- integralDict b + = peek + + floating' :: FloatingType b -> SingleIntegralType a -> Vec n a -> b + floating' (VectorFloatingType _ _) _ = unsafeCoerce + floating' (SingleFloatingType b) _ + | FloatingDict <- floatingDict b + = peek + + floating :: FloatingType a -> ScalarType b -> a -> b + floating (SingleFloatingType a) = \case + NumScalarType b -> num' b a + BitScalarType b -> bit' b a + where + bit' :: BitType b -> SingleFloatingType a -> a -> b + bit' TypeBit _ = unsafeCoerce + bit' TypeMask{} _ = unsafeCoerce + + num' :: NumType b -> SingleFloatingType a -> a -> b + num' (IntegralNumType b) = integral' b + num' (FloatingNumType b) = floating' b + + integral' :: IntegralType b -> SingleFloatingType a -> a -> b + integral' (SingleIntegralType _) _ = unsafeCoerce + integral' (VectorIntegralType _ b) a + | FloatingDict <- floatingDict a + = case b of + TypeInt8 -> poke + TypeInt16 -> poke + TypeInt32 -> poke + TypeInt64 -> poke + TypeInt128 -> poke + TypeWord8 -> poke + TypeWord16 -> poke + TypeWord32 -> poke + TypeWord64 -> poke + TypeWord128 -> poke + + floating' :: FloatingType b -> SingleFloatingType a -> a -> b + floating' (SingleFloatingType _) _ = unsafeCoerce + floating' (VectorFloatingType _ b) a + | FloatingDict <- floatingDict a + = case b of + TypeFloat16 -> poke + TypeFloat32 -> poke + TypeFloat64 -> poke + TypeFloat128 -> poke + + floating (VectorFloatingType _ a) = \case + NumScalarType b -> num' b a + BitScalarType b -> bit' b a + where + bit' :: BitType b -> SingleFloatingType a -> Vec n a -> b + bit' TypeBit _ = unsafeCoerce + bit' TypeMask{} _ = unsafeCoerce + + num' :: NumType b -> SingleFloatingType a -> Vec n a -> b + num' (IntegralNumType b) = integral' b + num' (FloatingNumType b) = floating' b + + integral' :: IntegralType b -> SingleFloatingType a -> Vec n a -> b + integral' (VectorIntegralType _ _) _ = unsafeCoerce + integral' (SingleIntegralType b) _ + | IntegralDict <- integralDict b + = peek + + floating' :: FloatingType b -> SingleFloatingType a -> Vec n a -> b + floating' (VectorFloatingType _ _) _ = unsafeCoerce + floating' (SingleFloatingType b) _ + | FloatingDict <- floatingDict b + = peek + + {-# INLINE poke #-} + poke :: forall a b n. Prim a => a -> Vec n b + poke x = runST $ do + mba <- newByteArray (sizeOf (undefined::a)) + writeByteArray mba 0 x + ByteArray ba# <- unsafeFreezeByteArray mba + return $ Vec ba# + + {-# INLINE peek #-} + peek :: Prim b => Vec n a -> b + peek (Vec ba#) = indexByteArray (ByteArray ba#) 0 + +{-- evalCoerceScalar SingleScalarType{} SingleScalarType{} a = unsafeCoerce a evalCoerceScalar VectorScalarType{} VectorScalarType{} a = unsafeCoerce a -- XXX: or just unpack/repack the (Vec ba#) evalCoerceScalar (SingleScalarType ta) VectorScalarType{} a = vector ta a @@ -1073,797 +1390,329 @@ evalCoerceScalar VectorScalarType{} (SingleScalarType tb) a = scalar tb a {-# INLINE peek #-} peek :: Prim a => Vec n b -> a peek (Vec ba#) = indexByteArray (ByteArray ba#) 0 - +--} -- Scalar primitives -- ----------------- -evalPrimConst :: PrimConst a -> a -evalPrimConst (PrimMinBound ty) = evalMinBound ty -evalPrimConst (PrimMaxBound ty) = evalMaxBound ty -evalPrimConst (PrimPi ty) = evalPi ty - evalPrim :: PrimFun (a -> r) -> (a -> r) -evalPrim (PrimAdd ty) = evalAdd ty -evalPrim (PrimSub ty) = evalSub ty -evalPrim (PrimMul ty) = evalMul ty -evalPrim (PrimNeg ty) = evalNeg ty -evalPrim (PrimAbs ty) = evalAbs ty -evalPrim (PrimSig ty) = evalSig ty -evalPrim (PrimQuot ty) = evalQuot ty -evalPrim (PrimRem ty) = evalRem ty -evalPrim (PrimQuotRem ty) = evalQuotRem ty -evalPrim (PrimIDiv ty) = evalIDiv ty -evalPrim (PrimMod ty) = evalMod ty -evalPrim (PrimDivMod ty) = evalDivMod ty -evalPrim (PrimBAnd ty) = evalBAnd ty -evalPrim (PrimBOr ty) = evalBOr ty -evalPrim (PrimBXor ty) = evalBXor ty -evalPrim (PrimBNot ty) = evalBNot ty -evalPrim (PrimBShiftL ty) = evalBShiftL ty -evalPrim (PrimBShiftR ty) = evalBShiftR ty -evalPrim (PrimBRotateL ty) = evalBRotateL ty -evalPrim (PrimBRotateR ty) = evalBRotateR ty -evalPrim (PrimPopCount ty) = evalPopCount ty -evalPrim (PrimCountLeadingZeros ty) = evalCountLeadingZeros ty -evalPrim (PrimCountTrailingZeros ty) = evalCountTrailingZeros ty -evalPrim (PrimFDiv ty) = evalFDiv ty -evalPrim (PrimRecip ty) = evalRecip ty -evalPrim (PrimSin ty) = evalSin ty -evalPrim (PrimCos ty) = evalCos ty -evalPrim (PrimTan ty) = evalTan ty -evalPrim (PrimAsin ty) = evalAsin ty -evalPrim (PrimAcos ty) = evalAcos ty -evalPrim (PrimAtan ty) = evalAtan ty -evalPrim (PrimSinh ty) = evalSinh ty -evalPrim (PrimCosh ty) = evalCosh ty -evalPrim (PrimTanh ty) = evalTanh ty -evalPrim (PrimAsinh ty) = evalAsinh ty -evalPrim (PrimAcosh ty) = evalAcosh ty -evalPrim (PrimAtanh ty) = evalAtanh ty -evalPrim (PrimExpFloating ty) = evalExpFloating ty -evalPrim (PrimSqrt ty) = evalSqrt ty -evalPrim (PrimLog ty) = evalLog ty -evalPrim (PrimFPow ty) = evalFPow ty -evalPrim (PrimLogBase ty) = evalLogBase ty -evalPrim (PrimTruncate ta tb) = evalTruncate ta tb -evalPrim (PrimRound ta tb) = evalRound ta tb -evalPrim (PrimFloor ta tb) = evalFloor ta tb -evalPrim (PrimCeiling ta tb) = evalCeiling ta tb -evalPrim (PrimAtan2 ty) = evalAtan2 ty -evalPrim (PrimIsNaN ty) = evalIsNaN ty -evalPrim (PrimIsInfinite ty) = evalIsInfinite ty -evalPrim (PrimLt ty) = evalLt ty -evalPrim (PrimGt ty) = evalGt ty -evalPrim (PrimLtEq ty) = evalLtEq ty -evalPrim (PrimGtEq ty) = evalGtEq ty -evalPrim (PrimEq ty) = evalEq ty -evalPrim (PrimNEq ty) = evalNEq ty -evalPrim (PrimMax ty) = evalMax ty -evalPrim (PrimMin ty) = evalMin ty -evalPrim PrimLAnd = evalLAnd -evalPrim PrimLOr = evalLOr -evalPrim PrimLNot = evalLNot -evalPrim (PrimFromIntegral ta tb) = evalFromIntegral ta tb -evalPrim (PrimToFloating ta tb) = evalToFloating ta tb - - --- Implementation of scalar primitives --- ----------------------------------- - -toBool :: PrimBool -> Bool -toBool 0 = False -toBool _ = True - -fromBool :: Bool -> PrimBool -fromBool False = 0 -fromBool True = 1 - -evalLAnd :: (PrimBool, PrimBool) -> PrimBool -evalLAnd (x, y) = fromBool (toBool x && toBool y) - -evalLOr :: (PrimBool, PrimBool) -> PrimBool -evalLOr (x, y) = fromBool (toBool x || toBool y) - -evalLNot :: PrimBool -> PrimBool -evalLNot = fromBool . not . toBool - -evalFromIntegral :: IntegralType a -> NumType b -> a -> b -evalFromIntegral ta (IntegralNumType tb) - | IntegralDict <- integralDict ta - , IntegralDict <- integralDict tb - = fromIntegral - -evalFromIntegral ta (FloatingNumType tb) - | IntegralDict <- integralDict ta - , FloatingDict <- floatingDict tb - = fromIntegral - -evalToFloating :: NumType a -> FloatingType b -> a -> b -evalToFloating (IntegralNumType ta) tb - | IntegralDict <- integralDict ta - , FloatingDict <- floatingDict tb - = realToFrac - -evalToFloating (FloatingNumType ta) tb - | FloatingDict <- floatingDict ta - , FloatingDict <- floatingDict tb - = realToFrac - - --- Extract methods from reified dictionaries --- - --- Constant methods of Bounded --- - -evalMinBound :: BoundedType a -> a -evalMinBound (IntegralBoundedType ty) - | IntegralDict <- integralDict ty - = minBound - -evalMaxBound :: BoundedType a -> a -evalMaxBound (IntegralBoundedType ty) - | IntegralDict <- integralDict ty - = maxBound - --- Constant method of floating --- - -evalPi :: FloatingType a -> a -evalPi ty | FloatingDict <- floatingDict ty = pi - -evalSin :: FloatingType a -> (a -> a) -evalSin ty | FloatingDict <- floatingDict ty = sin - -evalCos :: FloatingType a -> (a -> a) -evalCos ty | FloatingDict <- floatingDict ty = cos - -evalTan :: FloatingType a -> (a -> a) -evalTan ty | FloatingDict <- floatingDict ty = tan - -evalAsin :: FloatingType a -> (a -> a) -evalAsin ty | FloatingDict <- floatingDict ty = asin - -evalAcos :: FloatingType a -> (a -> a) -evalAcos ty | FloatingDict <- floatingDict ty = acos - -evalAtan :: FloatingType a -> (a -> a) -evalAtan ty | FloatingDict <- floatingDict ty = atan - -evalSinh :: FloatingType a -> (a -> a) -evalSinh ty | FloatingDict <- floatingDict ty = sinh - -evalCosh :: FloatingType a -> (a -> a) -evalCosh ty | FloatingDict <- floatingDict ty = cosh - -evalTanh :: FloatingType a -> (a -> a) -evalTanh ty | FloatingDict <- floatingDict ty = tanh - -evalAsinh :: FloatingType a -> (a -> a) -evalAsinh ty | FloatingDict <- floatingDict ty = asinh - -evalAcosh :: FloatingType a -> (a -> a) -evalAcosh ty | FloatingDict <- floatingDict ty = acosh - -evalAtanh :: FloatingType a -> (a -> a) -evalAtanh ty | FloatingDict <- floatingDict ty = atanh - -evalExpFloating :: FloatingType a -> (a -> a) -evalExpFloating ty | FloatingDict <- floatingDict ty = exp - -evalSqrt :: FloatingType a -> (a -> a) -evalSqrt ty | FloatingDict <- floatingDict ty = sqrt - -evalLog :: FloatingType a -> (a -> a) -evalLog ty | FloatingDict <- floatingDict ty = log - -evalFPow :: FloatingType a -> ((a, a) -> a) -evalFPow ty | FloatingDict <- floatingDict ty = uncurry (**) - -evalLogBase :: FloatingType a -> ((a, a) -> a) -evalLogBase ty | FloatingDict <- floatingDict ty = uncurry logBase - -evalTruncate :: FloatingType a -> IntegralType b -> (a -> b) -evalTruncate ta tb - | FloatingDict <- floatingDict ta - , IntegralDict <- integralDict tb - = truncate - -evalRound :: FloatingType a -> IntegralType b -> (a -> b) -evalRound ta tb - | FloatingDict <- floatingDict ta - , IntegralDict <- integralDict tb - = round - -evalFloor :: FloatingType a -> IntegralType b -> (a -> b) -evalFloor ta tb - | FloatingDict <- floatingDict ta - , IntegralDict <- integralDict tb - = floor - -evalCeiling :: FloatingType a -> IntegralType b -> (a -> b) -evalCeiling ta tb - | FloatingDict <- floatingDict ta - , IntegralDict <- integralDict tb - = ceiling - -evalAtan2 :: FloatingType a -> ((a, a) -> a) -evalAtan2 ty | FloatingDict <- floatingDict ty = uncurry atan2 - -evalIsNaN :: FloatingType a -> (a -> PrimBool) -evalIsNaN ty | FloatingDict <- floatingDict ty = fromBool . isNaN - -evalIsInfinite :: FloatingType a -> (a -> PrimBool) -evalIsInfinite ty | FloatingDict <- floatingDict ty = fromBool . isInfinite - - --- Methods of Num --- - -evalAdd :: NumType a -> ((a, a) -> a) -evalAdd (IntegralNumType ty) | IntegralDict <- integralDict ty = uncurry (+) -evalAdd (FloatingNumType ty) | FloatingDict <- floatingDict ty = uncurry (+) - -evalSub :: NumType a -> ((a, a) -> a) -evalSub (IntegralNumType ty) | IntegralDict <- integralDict ty = uncurry (-) -evalSub (FloatingNumType ty) | FloatingDict <- floatingDict ty = uncurry (-) - -evalMul :: NumType a -> ((a, a) -> a) -evalMul (IntegralNumType ty) | IntegralDict <- integralDict ty = uncurry (*) -evalMul (FloatingNumType ty) | FloatingDict <- floatingDict ty = uncurry (*) - -evalNeg :: NumType a -> (a -> a) -evalNeg (IntegralNumType ty) | IntegralDict <- integralDict ty = negate -evalNeg (FloatingNumType ty) | FloatingDict <- floatingDict ty = negate - -evalAbs :: NumType a -> (a -> a) -evalAbs (IntegralNumType ty) | IntegralDict <- integralDict ty = abs -evalAbs (FloatingNumType ty) | FloatingDict <- floatingDict ty = abs - -evalSig :: NumType a -> (a -> a) -evalSig (IntegralNumType ty) | IntegralDict <- integralDict ty = signum -evalSig (FloatingNumType ty) | FloatingDict <- floatingDict ty = signum - -evalQuot :: IntegralType a -> ((a, a) -> a) -evalQuot ty | IntegralDict <- integralDict ty = uncurry quot - -evalRem :: IntegralType a -> ((a, a) -> a) -evalRem ty | IntegralDict <- integralDict ty = uncurry rem - -evalQuotRem :: IntegralType a -> ((a, a) -> (a, a)) -evalQuotRem ty | IntegralDict <- integralDict ty = uncurry quotRem - -evalIDiv :: IntegralType a -> ((a, a) -> a) -evalIDiv ty | IntegralDict <- integralDict ty = uncurry div - -evalMod :: IntegralType a -> ((a, a) -> a) -evalMod ty | IntegralDict <- integralDict ty = uncurry mod - -evalDivMod :: IntegralType a -> ((a, a) -> (a, a)) -evalDivMod ty | IntegralDict <- integralDict ty = uncurry divMod - -evalBAnd :: IntegralType a -> ((a, a) -> a) -evalBAnd ty | IntegralDict <- integralDict ty = uncurry (.&.) - -evalBOr :: IntegralType a -> ((a, a) -> a) -evalBOr ty | IntegralDict <- integralDict ty = uncurry (.|.) - -evalBXor :: IntegralType a -> ((a, a) -> a) -evalBXor ty | IntegralDict <- integralDict ty = uncurry xor - -evalBNot :: IntegralType a -> (a -> a) -evalBNot ty | IntegralDict <- integralDict ty = complement - -evalBShiftL :: IntegralType a -> ((a, Int) -> a) -evalBShiftL ty | IntegralDict <- integralDict ty = uncurry shiftL - -evalBShiftR :: IntegralType a -> ((a, Int) -> a) -evalBShiftR ty | IntegralDict <- integralDict ty = uncurry shiftR - -evalBRotateL :: IntegralType a -> ((a, Int) -> a) -evalBRotateL ty | IntegralDict <- integralDict ty = uncurry rotateL - -evalBRotateR :: IntegralType a -> ((a, Int) -> a) -evalBRotateR ty | IntegralDict <- integralDict ty = uncurry rotateR - -evalPopCount :: IntegralType a -> (a -> Int) -evalPopCount ty | IntegralDict <- integralDict ty = popCount - -evalCountLeadingZeros :: IntegralType a -> (a -> Int) -evalCountLeadingZeros ty | IntegralDict <- integralDict ty = countLeadingZeros - -evalCountTrailingZeros :: IntegralType a -> (a -> Int) -evalCountTrailingZeros ty | IntegralDict <- integralDict ty = countTrailingZeros - -evalFDiv :: FloatingType a -> ((a, a) -> a) -evalFDiv ty | FloatingDict <- floatingDict ty = uncurry (/) - -evalRecip :: FloatingType a -> (a -> a) -evalRecip ty | FloatingDict <- floatingDict ty = recip - - -evalLt :: SingleType a -> ((a, a) -> PrimBool) -evalLt (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = fromBool . uncurry (<) -evalLt (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = fromBool . uncurry (<) - -evalGt :: SingleType a -> ((a, a) -> PrimBool) -evalGt (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = fromBool . uncurry (>) -evalGt (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = fromBool . uncurry (>) - -evalLtEq :: SingleType a -> ((a, a) -> PrimBool) -evalLtEq (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = fromBool . uncurry (<=) -evalLtEq (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = fromBool . uncurry (<=) - -evalGtEq :: SingleType a -> ((a, a) -> PrimBool) -evalGtEq (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = fromBool . uncurry (>=) -evalGtEq (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = fromBool . uncurry (>=) - -evalEq :: SingleType a -> ((a, a) -> PrimBool) -evalEq (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = fromBool . uncurry (==) -evalEq (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = fromBool . uncurry (==) - -evalNEq :: SingleType a -> ((a, a) -> PrimBool) -evalNEq (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = fromBool . uncurry (/=) -evalNEq (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = fromBool . uncurry (/=) - -evalMax :: SingleType a -> ((a, a) -> a) -evalMax (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = uncurry max -evalMax (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = uncurry max - -evalMin :: SingleType a -> ((a, a) -> a) -evalMin (NumSingleType (IntegralNumType ty)) | IntegralDict <- integralDict ty = uncurry min -evalMin (NumSingleType (FloatingNumType ty)) | FloatingDict <- floatingDict ty = uncurry min - - -{-- --- Sequence evaluation --- --------------- - --- Position in sequence. --- -type SeqPos = Int - --- Configuration for sequence evaluation. --- -data SeqConfig = SeqConfig - { chunkSize :: Int -- Allocation limit for a sequence in - -- words. Actual runtime allocation should be the - -- maximum of this size and the size of the - -- largest element in the sequence. - } - --- Default sequence evaluation configuration for testing purposes. --- -defaultSeqConfig :: SeqConfig -defaultSeqConfig = SeqConfig { chunkSize = 2 } - -type Chunk a = Vector' a - --- The empty chunk. O(1). -emptyChunk :: Arrays a => Chunk a -emptyChunk = empty' - --- Number of arrays in chunk. O(1). --- -clen :: Arrays a => Chunk a -> Int -clen = length' - -elemsPerChunk :: SeqConfig -> Int -> Int -elemsPerChunk conf n - | n < 1 = chunkSize conf - | otherwise = - let (a,b) = chunkSize conf `quotRem` n - in a + signum b - --- Drop a number of arrays from a chunk. O(1). Note: Require keeping a --- scan of element sizes. --- -cdrop :: Arrays a => Int -> Chunk a -> Chunk a -cdrop = drop' dropOp (fst . offsetsOp) - --- Get all the shapes of a chunk of arrays. O(1). --- -chunkShapes :: Chunk (Array sh a) -> Vector sh -chunkShapes = shapes' - --- Get all the elements of a chunk of arrays. O(1). --- -chunkElems :: Chunk (Array sh a) -> Vector a -chunkElems = elements' - --- Convert a vector to a chunk of scalars. --- -vec2Chunk :: Elt e => Vector e -> Chunk (Scalar e) -vec2Chunk = vec2Vec' - --- Convert a list of arrays to a chunk. --- -fromListChunk :: Arrays a => [a] -> Vector' a -fromListChunk = fromList' concatOp - --- Convert a chunk to a list of arrays. --- -toListChunk :: Arrays a => Vector' a -> [a] -toListChunk = toList' fetchAllOp - --- fmap for Chunk. O(n). --- TODO: Use vectorised function. -mapChunk :: (Arrays a, Arrays b) - => (a -> b) - -> Chunk a -> Chunk b -mapChunk f c = fromListChunk $ map f (toListChunk c) - --- zipWith for Chunk. O(n). --- TODO: Use vectorised function. -zipWithChunk :: (Arrays a, Arrays b, Arrays c) - => (a -> b -> c) - -> Chunk a -> Chunk b -> Chunk c -zipWithChunk f c1 c2 = fromListChunk $ zipWith f (toListChunk c1) (toListChunk c2) - --- A window on a sequence. --- -data Window a = Window - { chunk :: Chunk a -- Current allocated chunk. - , wpos :: SeqPos -- Position of the window on the sequence, given - -- in number of elements. - } - --- The initial empty window. --- -window0 :: Arrays a => Window a -window0 = Window { chunk = emptyChunk, wpos = 0 } - --- Index the given window by the given index on the sequence. --- -(!#) :: Arrays a => Window a -> SeqPos -> Chunk a -w !# i - | j <- i - wpos w - , j >= 0 - = cdrop j (chunk w) - -- - | otherwise - = error $ "Window indexed before position. wpos = " ++ show (wpos w) ++ " i = " ++ show i - --- Move the give window by supplying the next chunk. --- -moveWin :: Arrays a => Window a -> Chunk a -> Window a -moveWin w c = w { chunk = c - , wpos = wpos w + clen (chunk w) - } - --- A cursor on a sequence. --- -data Cursor senv a = Cursor - { ref :: Idx senv a -- Reference to the sequence. - , cpos :: SeqPos -- Position of the cursor on the sequence, - -- given in number of elements. - } - --- Initial cursor. --- -cursor0 :: Idx senv a -> Cursor senv a -cursor0 x = Cursor { ref = x, cpos = 0 } - --- Advance cursor by a relative amount. --- -moveCursor :: Int -> Cursor senv a -> Cursor senv a -moveCursor k c = c { cpos = cpos c + k } +evalPrim = \case + PrimAdd t -> A.add t + PrimSub t -> A.sub t + PrimMul t -> A.mul t + PrimNeg t -> A.negate t + PrimAbs t -> A.abs t + PrimSig t -> A.signum t + PrimVAdd t -> A.vadd t + PrimVMul t -> A.vmul t + PrimQuot t -> A.quot t + PrimRem t -> A.rem t + PrimQuotRem t -> A.quotRem t + PrimIDiv t -> A.div t + PrimMod t -> A.mod t + PrimDivMod t -> A.divMod t + PrimBAnd t -> A.band t + PrimBOr t -> A.bor t + PrimBXor t -> A.xor t + PrimBNot t -> A.complement t + PrimBShiftL t -> A.shiftL t + PrimBShiftR t -> A.shiftR t + PrimBRotateL t -> A.rotateL t + PrimBRotateR t -> A.rotateR t + PrimPopCount t -> A.popCount t + PrimCountLeadingZeros t -> A.countLeadingZeros t + PrimCountTrailingZeros t -> A.countTrailingZeros t + PrimBReverse t -> A.bitreverse t + PrimBSwap t -> A.byteswap t + PrimVBAnd t -> A.vband t + PrimVBOr t -> A.vbor t + PrimVBXor t -> A.vbxor t + PrimFDiv t -> A.fdiv t + PrimRecip t -> A.recip t + PrimSin t -> A.sin t + PrimCos t -> A.cos t + PrimTan t -> A.tan t + PrimAsin t -> A.asin t + PrimAcos t -> A.acos t + PrimAtan t -> A.atan t + PrimSinh t -> A.sinh t + PrimCosh t -> A.cosh t + PrimTanh t -> A.tanh t + PrimAsinh t -> A.asinh t + PrimAcosh t -> A.acosh t + PrimAtanh t -> A.atanh t + PrimExpFloating t -> A.exp t + PrimSqrt t -> A.sqrt t + PrimLog t -> A.log t + PrimFPow t -> A.pow t + PrimLogBase t -> A.logBase t + PrimTruncate ta tb -> A.truncate ta tb + PrimRound ta tb -> A.round ta tb + PrimFloor ta tb -> A.floor ta tb + PrimCeiling ta tb -> A.ceiling ta tb + PrimAtan2 t -> A.atan2 t + PrimIsNaN t -> A.isNaN t + PrimIsInfinite t -> A.isInfinite t + PrimLt t -> A.lt t + PrimGt t -> A.gt t + PrimLtEq t -> A.lte t + PrimGtEq t -> A.gte t + PrimEq t -> A.eq t + PrimNEq t -> A.neq t + PrimMin t -> A.min t + PrimMax t -> A.max t + PrimVMin t -> A.vmin t + PrimVMax t -> A.vmax t + PrimLAnd t -> A.land t + PrimLOr t -> A.lor t + PrimLNot t -> A.lnot t + PrimVLAnd t -> A.vland t + PrimVLOr t -> A.vlor t + PrimFromIntegral ta tb -> A.fromIntegral ta tb + PrimToFloating ta tb -> A.toFloating ta tb + PrimToBool i b -> A.toBool i b + PrimFromBool b i -> A.fromBool b i + + +-- Vector primitives +-- ----------------- --- Valuation for an environment of sequence windows. --- -data Val' senv where - Empty' :: Val' () - Push' :: Val' senv -> Window t -> Val' (senv, t) +evalExtract :: ScalarType (Prim.Vec n a) -> SingleIntegralType i -> Prim.Vec n a -> i -> a +evalExtract vR iR v i = scalar vR v + where + scalar :: ScalarType (Prim.Vec n t) -> Prim.Vec n t -> t + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t --- Projection of a window from a window valuation using a de Bruijn --- index. --- -prj' :: Idx senv t -> Val' senv -> Window t -prj' ZeroIdx (Push' _ v) = v -prj' (SuccIdx idx) (Push' val _) = prj' idx val + num :: NumType (Prim.Vec n t) -> Prim.Vec n t -> t + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t --- Projection of a chunk from a window valuation using a sequence --- cursor. --- -prjChunk :: Arrays a => Cursor senv a -> Val' senv -> Chunk a -prjChunk c senv = prj' (ref c) senv !# cpos c + bit :: BitType (Prim.Vec n t) -> Prim.Vec n t -> t + bit TypeMask{} v + | IntegralDict <- integralDict iR + = Bit.extract (BitMask v) (fromIntegral i) + + integral :: IntegralType (Prim.Vec n t) -> Prim.Vec n t -> t + integral (SingleIntegralType tR) _ = case tR of + integral (VectorIntegralType _ tR) v + | IntegralDict <- integralDict tR + , IntegralDict <- integralDict iR + = Vec.extract v (fromIntegral i) + + floating :: FloatingType (Prim.Vec n t) -> Prim.Vec n t -> t + floating (SingleFloatingType tR) _ = case tR of + floating (VectorFloatingType _ tR) v + | FloatingDict <- floatingDict tR + , IntegralDict <- integralDict iR + = Vec.extract v (fromIntegral i) + +evalInsert + :: ScalarType (Prim.Vec n a) + -> SingleIntegralType i + -> Prim.Vec n a + -> i + -> a + -> Prim.Vec n a +evalInsert vR iR v i x = scalar vR v x + where + scalar :: ScalarType (Prim.Vec n t) -> Prim.Vec n t -> t -> Prim.Vec n t + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t --- An executable sequence. --- -data ExecSeq senv arrs where - ExecP :: Arrays a => Window a -> ExecP senv a -> ExecSeq (senv, a) arrs -> ExecSeq senv arrs - ExecC :: Arrays a => ExecC senv a -> ExecSeq senv a - ExecR :: Arrays a => Cursor senv a -> ExecSeq senv [a] + num :: NumType (Prim.Vec n t) -> Prim.Vec n t -> t -> Prim.Vec n t + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t --- An executable producer. --- -data ExecP senv a where - ExecStreamIn :: Int - -> [a] - -> ExecP senv a - - ExecMap :: Arrays a - => (Chunk a -> Chunk b) - -> Cursor senv a - -> ExecP senv b - - ExecZipWith :: (Arrays a, Arrays b) - => (Chunk a -> Chunk b -> Chunk c) - -> Cursor senv a - -> Cursor senv b - -> ExecP senv c - - -- Stream scan skeleton. - ExecScan :: Arrays a - => (s -> Chunk a -> (Chunk r, s)) -- Chunk scanner. - -> s -- Accumulator (internal state). - -> Cursor senv a -- Input stream. - -> ExecP senv r - --- An executable consumer. --- -data ExecC senv a where - - -- Stream reduction skeleton. - ExecFold :: Arrays a - => (s -> Chunk a -> s) -- Chunk consumer function. - -> (s -> r) -- Finalizer function. - -> s -- Accumulator (internal state). - -> Cursor senv a -- Input stream. - -> ExecC senv r - - ExecStuple :: IsAtuple a - => Atuple (ExecC senv) (TupleRepr a) - -> ExecC senv a - -minCursor :: ExecSeq senv a -> SeqPos -minCursor s = travS s 0 + bit :: BitType (Prim.Vec n t) -> Prim.Vec n t -> t -> Prim.Vec n t + bit TypeMask{} v x + | IntegralDict <- integralDict iR + = unMask $ Bit.insert (BitMask v) (fromIntegral i) x + + integral :: IntegralType (Prim.Vec n t) -> Prim.Vec n t -> t -> Prim.Vec n t + integral (SingleIntegralType tR) _ _ = case tR of + integral (VectorIntegralType _ tR) v x + | IntegralDict <- integralDict tR + , IntegralDict <- integralDict iR + = Vec.insert v (fromIntegral i) x + + floating :: FloatingType (Prim.Vec n t) -> Prim.Vec n t -> t -> Prim.Vec n t + floating (SingleFloatingType tR) _ _ = case tR of + floating (VectorFloatingType _ tR) v x + | FloatingDict <- floatingDict tR + , IntegralDict <- integralDict iR + = Vec.insert v (fromIntegral i) x + +evalShuffle + :: ScalarType (Prim.Vec n a) + -> ScalarType (Prim.Vec m a) + -> SingleIntegralType i + -> Prim.Vec n a + -> Prim.Vec n a + -> Prim.Vec m i + -> Prim.Vec m a +evalShuffle = scalar where - travS :: ExecSeq senv a -> Int -> SeqPos - travS s i = - case s of - ExecP _ p s' -> travP p i `min` travS s' (i+1) - ExecC c -> travC c i - ExecR _ -> maxBound - - k :: Cursor senv a -> Int -> SeqPos - k c i - | i == idxToInt (ref c) = cpos c - | otherwise = maxBound - - travP :: ExecP senv a -> Int -> SeqPos - travP p i = - case p of - ExecStreamIn _ _ -> maxBound - ExecMap _ c -> k c i - ExecZipWith _ c1 c2 -> k c1 i `min` k c2 i - ExecScan _ _ c -> k c i - - travT :: Atuple (ExecC senv) t -> Int -> SeqPos - travT NilAtup _ = maxBound - travT (SnocAtup t c) i = travT t i `min` travC c i - - travC :: ExecC senv a -> Int -> SeqPos - travC c i = - case c of - ExecFold _ _ _ cu -> k cu i - ExecStuple t -> travT t i - - -evalDelayedSeq - :: SeqConfig - -> DelayedSeq arrs - -> arrs -evalDelayedSeq cfg (DelayedSeq aenv s) | aenv' <- evalExtend aenv Empty - = evalSeq cfg s aenv' - -evalSeq :: forall aenv arrs. - SeqConfig - -> PreOpenSeq DelayedOpenAcc aenv () arrs - -> Val aenv -> arrs -evalSeq conf s aenv = evalSeq' s + scalar :: ScalarType (Prim.Vec n t) + -> ScalarType (Prim.Vec m t) + -> SingleIntegralType i + -> Prim.Vec n t + -> Prim.Vec n t + -> Prim.Vec m i + -> Prim.Vec m t + scalar (NumScalarType s) (NumScalarType t) = num s t + scalar (BitScalarType s) (BitScalarType t) = bit s t + scalar _ _ = internalError "unexpected vector encoding" + + num :: NumType (Prim.Vec n t) + -> NumType (Prim.Vec m t) + -> SingleIntegralType i + -> Prim.Vec n t + -> Prim.Vec n t + -> Prim.Vec m i + -> Prim.Vec m t + num (IntegralNumType s) (IntegralNumType t) = integral s t + num (FloatingNumType s) (FloatingNumType t) = floating s t + num _ _ = internalError "unexpected vector encoding" + + bit :: BitType (Prim.Vec n t) + -> BitType (Prim.Vec m t) + -> SingleIntegralType i + -> Prim.Vec n t + -> Prim.Vec n t + -> Prim.Vec m i + -> Prim.Vec m t + bit (TypeMask n#) TypeMask{} iR x y i + | IntegralDict <- integralDict iR + = let n = fromInteger (natVal' n#) + in unMask + $ Bit.fromList [ boundsCheck "vector index" (j >= 0 && j < 2*n) + $ if j < n then Bit.extract (BitMask x) j + else Bit.extract (BitMask y) (j - n) + | j <- map fromIntegral (Vec.toList i) ] + + integral :: IntegralType (Prim.Vec n t) + -> IntegralType (Prim.Vec m t) + -> SingleIntegralType i + -> Prim.Vec n t + -> Prim.Vec n t + -> Prim.Vec m i + -> Prim.Vec m t + integral (SingleIntegralType s) _ _ _ _ _ = case s of + integral _ (SingleIntegralType t) _ _ _ _ = case t of + integral (VectorIntegralType n# sR) (VectorIntegralType _ tR) iR x y i + | IntegralDict <- integralDict iR + , IntegralDict <- integralDict sR + , IntegralDict <- integralDict tR + = let n = fromInteger (natVal' n#) + in Vec.fromList [ boundsCheck "vector index" (j >= 0 && j < 2*n) + $ if j < n then Vec.extract x j + else Vec.extract y (j - n) + | j <- map fromIntegral (Vec.toList i) ] + + floating :: FloatingType (Prim.Vec n t) + -> FloatingType (Prim.Vec m t) + -> SingleIntegralType i + -> Prim.Vec n t + -> Prim.Vec n t + -> Prim.Vec m i + -> Prim.Vec m t + floating (SingleFloatingType s) _ _ _ _ _ = case s of + floating _ (SingleFloatingType t) _ _ _ _ = case t of + floating (VectorFloatingType n# sR) (VectorFloatingType _ tR) iR x y i + | IntegralDict <- integralDict iR + , FloatingDict <- floatingDict sR + , FloatingDict <- floatingDict tR + = let n = fromInteger (natVal' n#) + in Vec.fromList [ boundsCheck "vector index" (j >= 0 && j < 2*n) + $ if j < n then Vec.extract x j + else Vec.extract y (j - n) + | j <- map fromIntegral (Vec.toList i) ] + +evalSelect + :: ScalarType (Prim.Vec n a) + -> Prim.Vec n Bit + -> Prim.Vec n a + -> Prim.Vec n a + -> Prim.Vec n a +evalSelect = scalar where - evalSeq' :: PreOpenSeq DelayedOpenAcc aenv senv arrs -> arrs - evalSeq' (Producer _ s) = evalSeq' s - evalSeq' (Consumer _) = loop (initSeq aenv s) - evalSeq' (Reify _) = reify (initSeq aenv s) - - -- Initialize the producers and the accumulators of the consumers - -- with the given array enviroment. - initSeq :: forall senv arrs'. - Val aenv - -> PreOpenSeq DelayedOpenAcc aenv senv arrs' - -> ExecSeq senv arrs' - initSeq aenv s = - case s of - Producer p s' -> ExecP window0 (initProducer p) (initSeq aenv s') - Consumer c -> ExecC (initConsumer c) - Reify ix -> ExecR (cursor0 ix) - - -- Generate a list from the sequence. - reify :: forall arrs. ExecSeq () [arrs] - -> [arrs] - reify s = case step s Empty' of - (Just s', a) -> a ++ reify s' - (Nothing, a) -> a - - -- Iterate the given sequence until it terminates. - -- A sequence only terminates when one of the producers are exhausted. - loop :: Arrays arrs - => ExecSeq () arrs - -> arrs - loop s = - case step' s of - (Nothing, arrs) -> arrs - (Just s', _) -> loop s' + scalar :: ScalarType (Prim.Vec n t) -> Prim.Vec n Bit -> Prim.Vec n t -> Prim.Vec n t -> Prim.Vec n t + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t - where - step' :: ExecSeq () arrs -> (Maybe (ExecSeq () arrs), arrs) - step' s = step s Empty' - - -- One iteration of a sequence. - step :: forall senv arrs'. - ExecSeq senv arrs' - -> Val' senv - -> (Maybe (ExecSeq senv arrs'), arrs') - step s senv = - case s of - ExecP w p s' -> - let (c, mp') = produce p senv - finished = 0 == clen (w !# minCursor s') - w' = if finished then moveWin w c else w - (ms'', a) = step s' (senv `Push'` w') - in case ms'' of - Nothing -> (Nothing, a) - Just s'' | finished - , Just p' <- mp' - -> (Just (ExecP w' p' s''), a) - | not finished - -> (Just (ExecP w' p s''), a) - | otherwise - -> (Nothing, a) - ExecC c -> let (c', acc) = consume c senv - in (Just (ExecC c'), acc) - ExecR ix -> let c = prjChunk ix senv in (Just (ExecR (moveCursor (clen c) ix)), toListChunk c) - - evalA :: DelayedOpenAcc aenv a -> a - evalA acc = evalOpenAcc acc aenv - - evalAF :: DelayedOpenAfun aenv f -> f - evalAF f = evalOpenAfun f aenv - - evalE :: DelayedExp aenv t -> t - evalE exp = evalExp exp aenv - - evalF :: DelayedFun aenv f -> f - evalF fun = evalFun fun aenv - - initProducer :: forall a senv. - Producer DelayedOpenAcc aenv senv a - -> ExecP senv a - initProducer p = - case p of - StreamIn arrs -> ExecStreamIn 1 arrs - ToSeq sliceIndex slix (delayed -> Delayed sh ix _) -> - let n = R.size (R.sliceShape sliceIndex (fromElt sh)) - k = elemsPerChunk conf n - in ExecStreamIn k (toSeqOp sliceIndex slix (fromFunction sh ix)) - MapSeq f x -> ExecMap (mapChunk (evalAF f)) (cursor0 x) - ChunkedMapSeq f x -> ExecMap (evalAF f) (cursor0 x) - ZipWithSeq f x y -> ExecZipWith (zipWithChunk (evalAF f)) (cursor0 x) (cursor0 y) - ScanSeq f e x -> ExecScan scanner (evalE e) (cursor0 x) - where - scanner a c = - let v0 = chunkElems c - (v1, a') = scanl'Op (evalF f) a (delayArray v0) - in (vec2Chunk v1, fromScalar a') - - initConsumer :: forall a senv. - Consumer DelayedOpenAcc aenv senv a - -> ExecC senv a - initConsumer c = - case c of - FoldSeq f e x -> - let f' = evalF f - a0 = fromFunction (Z :. chunkSize conf) (const (evalE e)) - consumer v c = zipWith'Op f' (delayArray v) (delayArray (chunkElems c)) - finalizer = fold1Op f' . delayArray - in ExecFold consumer finalizer a0 (cursor0 x) - FoldSeqFlatten f acc x -> - let f' = evalAF f - a0 = evalA acc - consumer a c = f' a (chunkShapes c) (chunkElems c) - in ExecFold consumer id a0 (cursor0 x) - Stuple t -> - let initTup :: Atuple (Consumer DelayedOpenAcc aenv senv) t -> Atuple (ExecC senv) t - initTup NilAtup = NilAtup - initTup (SnocAtup t c) = SnocAtup (initTup t) (initConsumer c) - in ExecStuple (initTup t) - - delayed :: DelayedOpenAcc aenv (Array sh e) -> Delayed (Array sh e) - delayed AST.Manifest{} = $internalError "evalOpenAcc" "expected delayed array" - delayed AST.Delayed{..} = Delayed (evalExp extentD aenv) - (evalFun indexD aenv) - (evalFun linearIndexD aenv) - -produce :: Arrays a => ExecP senv a -> Val' senv -> (Chunk a, Maybe (ExecP senv a)) -produce p senv = - case p of - ExecStreamIn k xs -> - let (xs', xs'') = (take k xs, drop k xs) - c = fromListChunk xs' - mp = if null xs'' - then Nothing - else Just (ExecStreamIn k xs'') - in (c, mp) - ExecMap f x -> - let c = prjChunk x senv - in (f c, Just $ ExecMap f (moveCursor (clen c) x)) - ExecZipWith f x y -> - let c1 = prjChunk x senv - c2 = prjChunk y senv - k = clen c1 `min` clen c2 - in (f c1 c2, Just $ ExecZipWith f (moveCursor k x) (moveCursor k y)) - ExecScan scanner a x -> - let c = prjChunk x senv - (c', a') = scanner a c - k = clen c - in (c', Just $ ExecScan scanner a' (moveCursor k x)) - -consume :: forall senv a. ExecC senv a -> Val' senv -> (ExecC senv a, a) -consume c senv = - case c of - ExecFold f g acc x -> - let c = prjChunk x senv - acc' = f acc c - -- Even though we call g here, lazy evaluation should guarantee it is - -- only ever called once. - in (ExecFold f g acc' (moveCursor (clen c) x), g acc') - ExecStuple t -> - let consT :: Atuple (ExecC senv) t -> (Atuple (ExecC senv) t, t) - consT NilAtup = (NilAtup, ()) - consT (SnocAtup t c) | (c', acc) <- consume c senv - , (t', acc') <- consT t - = (SnocAtup t' c', (acc', acc)) - (t', acc) = consT t - in (ExecStuple t', toAtuple acc) - -evalExtend :: Extend DelayedOpenAcc aenv aenv' -> Val aenv -> Val aenv' -evalExtend BaseEnv aenv = aenv -evalExtend (PushEnv ext1 ext2) aenv | aenv' <- evalExtend ext1 aenv - = Push aenv' (evalOpenAcc ext2 aenv') - -delayArray :: Array sh e -> Delayed (Array sh e) -delayArray arr@(Array _ adata) = Delayed (shape arr) (arr!) (toElt . unsafeIndexArrayData adata) - -fromScalar :: Scalar a -> a -fromScalar = (!Z) - -concatOp :: forall e. Elt e => [Vector e] -> Vector e -concatOp = concatVectors - -fetchAllOp :: (Shape sh, Elt e) => Segments sh -> Vector e -> [Array sh e] -fetchAllOp segs elts - | (offsets, n) <- offsetsOp segs - , (n ! Z) <= size (shape elts) - = [fetch (segs ! (Z :. i)) (offsets ! (Z :. i)) | i <- [0 .. size (shape segs) - 1]] - | otherwise = error $ "illegal argument to fetchAllOp" - where - fetch sh offset = fromFunction sh (\ ix -> elts ! (Z :. ((toIndex sh ix) + offset))) - -dropOp :: Elt e => Int -> Vector e -> Vector e -dropOp i v -- TODO - -- * Implement using C-style pointer-plus. - -- ; dropOp is used often (from prjChunk), - -- so it ought to be efficient O(1). - | n <- size (shape v) - , i <= n - , i >= 0 - = fromFunction (Z :. n - i) (\ (Z :. j) -> v ! (Z :. i + j)) - | otherwise = error $ "illegal argument to drop" - -offsetsOp :: Shape sh => Segments sh -> (Vector Int, Scalar Int) -offsetsOp segs = scanl'Op (+) 0 $ delayArray (mapOp size (delayArray segs)) ---} + num :: NumType (Prim.Vec n t) -> Prim.Vec n Bit -> Prim.Vec n t -> Prim.Vec n t -> Prim.Vec n t + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + bit :: BitType (Prim.Vec n t) -> Prim.Vec n Bit -> Prim.Vec n t -> Prim.Vec n t -> Prim.Vec n t + bit TypeMask{} m x y + = unMask + $ Bit.fromList [ if unBit b then Bit.extract (BitMask x) i + else Bit.extract (BitMask y) i + | b <- Bit.toList (BitMask m) + | i <- [0..] + ] + + integral :: IntegralType (Prim.Vec n t) -> Prim.Vec n Bit -> Prim.Vec n t -> Prim.Vec n t -> Prim.Vec n t + integral (SingleIntegralType tR) _ _ _ = case tR of + integral (VectorIntegralType _ tR) m x y + | IntegralDict <- integralDict tR + = Vec.fromList [ if unBit b then Vec.extract x i + else Vec.extract y i + | b <- Bit.toList (BitMask m) + | i <- [0..] + ] + + floating :: FloatingType (Prim.Vec n t) -> Prim.Vec n Bit -> Prim.Vec n t -> Prim.Vec n t -> Prim.Vec n t + floating (SingleFloatingType tR) _ _ _ = case tR of + floating (VectorFloatingType _ tR) m x y + | FloatingDict <- floatingDict tR + = Vec.fromList [ if unBit b then Vec.extract x i + else Vec.extract y i + | b <- Bit.toList (BitMask m) + | i <- [0..] + ] + + +-- Utilities +-- --------- + +toBool :: PrimBool -> Bool +toBool = unBit + +data IntegralDict t where + IntegralDict :: (Integral t, Prim t) => IntegralDict t + +data FloatingDict t where + FloatingDict :: (RealFloat t, Prim t) => FloatingDict t + +data TagDict t where + TagDict :: Eq t => TagDict t + +{-# INLINE integralDict #-} +integralDict :: SingleIntegralType t -> IntegralDict t +integralDict TypeInt8 = IntegralDict +integralDict TypeInt16 = IntegralDict +integralDict TypeInt32 = IntegralDict +integralDict TypeInt64 = IntegralDict +integralDict TypeInt128 = IntegralDict +integralDict TypeWord8 = IntegralDict +integralDict TypeWord16 = IntegralDict +integralDict TypeWord32 = IntegralDict +integralDict TypeWord64 = IntegralDict +integralDict TypeWord128 = IntegralDict + +{-# INLINE floatingDict #-} +floatingDict :: SingleFloatingType t -> FloatingDict t +floatingDict TypeFloat16 = FloatingDict +floatingDict TypeFloat32 = FloatingDict +floatingDict TypeFloat64 = FloatingDict +floatingDict TypeFloat128 = FloatingDict + +{-# INLINE tagDict #-} +tagDict :: TagType t -> TagDict t +tagDict TagBit = TagDict +tagDict TagWord8 = TagDict +tagDict TagWord16 = TagDict diff --git a/src/Data/Array/Accelerate/Interpreter/Arithmetic.hs b/src/Data/Array/Accelerate/Interpreter/Arithmetic.hs new file mode 100644 index 000000000..cc914d1a2 --- /dev/null +++ b/src/Data/Array/Accelerate/Interpreter/Arithmetic.hs @@ -0,0 +1,722 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_HADDOCK hide #-} +-- | +-- Module : Data.Array.Accelerate.Interpreter.Arithmetic +-- Copyright : [2008..2022] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Interpreter.Arithmetic ( + + add, sub, mul, negate, abs, signum, vadd, vmul, + quot, rem, quotRem, div, mod, divMod, + band, bor, xor, complement, shiftL, shiftR, rotateL, rotateR, popCount, countLeadingZeros, countTrailingZeros, bitreverse, byteswap, vband, vbor, vbxor, + fdiv, recip, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh, exp, sqrt, log, pow, logBase, + truncate, round, floor, ceiling, + atan2, isNaN, isInfinite, + lt, gt, lte, gte, eq, neq, min, max, vmin, vmax, + land, lor, lnot, vland, vlor, + fromIntegral, toFloating, toBool, fromBool, + +) where + +import Data.Array.Accelerate.AST +import Data.Array.Accelerate.Error +import Data.Array.Accelerate.Type + +import Data.Primitive.Bit ( BitMask(..) ) +import Data.Primitive.Vec ( Vec ) +import qualified Data.Primitive.Bit as Bit +import qualified Data.Primitive.Vec as Vec + +import Data.Primitive.Types + +import Data.Bits ( (.&.), (.|.) ) +import Data.Bool +import Data.Maybe +import Data.Type.Equality +import Formatting +import Prelude ( ($), (.) ) +import qualified Data.Bits as P +import qualified Prelude as P + +import GHC.Int +import GHC.Word +import GHC.Exts +import GHC.TypeLits +import GHC.TypeLits.Extra + + +-- Operators from Num +-- ------------------ + +add :: NumType a -> ((a, a) -> a) +add = num2 (P.+) + +sub :: NumType a -> ((a, a) -> a) +sub = num2 (P.-) + +mul :: NumType a -> ((a, a) -> a) +mul = num2 (P.*) + +negate :: NumType a -> (a -> a) +negate = num1 P.negate + +abs :: NumType a -> (a -> a) +abs = num1 P.abs + +signum :: NumType a -> (a -> a) +signum = num1 P.signum + +vadd :: NumType (Vec n a) -> (Vec n a -> a) +vadd = vnum2 (P.+) + +vmul :: NumType (Vec n a) -> (Vec n a -> a) +vmul = vnum2 (P.*) + +num1 :: (forall t. P.Num t => t -> t) -> NumType a -> (a -> a) +num1 f = \case + IntegralNumType t -> integral t + FloatingNumType t -> floating t + where + integral :: IntegralType t -> (t -> t) + integral (SingleIntegralType t) | IntegralDict <- integralDict t = f + integral (VectorIntegralType _ t) | IntegralDict <- integralDict t = map f + -- + floating :: FloatingType t -> (t -> t) + floating (SingleFloatingType t) | FloatingDict <- floatingDict t = f + floating (VectorFloatingType _ t) | FloatingDict <- floatingDict t = map f + +num2 :: (forall t. P.Num t => t -> t -> t) -> NumType a -> ((a, a) -> a) +num2 f = \case + IntegralNumType t -> integral t + FloatingNumType t -> floating t + where + integral :: IntegralType t -> ((t, t) -> t) + integral (SingleIntegralType t) | IntegralDict <- integralDict t = P.uncurry f + integral (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.uncurry (zipWith f) + -- + floating :: FloatingType t -> ((t, t) -> t) + floating (SingleFloatingType t) | FloatingDict <- floatingDict t = P.uncurry f + floating (VectorFloatingType _ t) | FloatingDict <- floatingDict t = P.uncurry (zipWith f) + +vnum2 :: (forall t. P.Num t => t -> t -> t) -> NumType (Vec n a) -> (Vec n a -> a) +vnum2 f = \case + IntegralNumType t -> integral t + FloatingNumType t -> floating t + where + integral :: IntegralType (Vec n t) -> (Vec n t -> t) + integral = \case + SingleIntegralType t -> case t of + VectorIntegralType _ t | IntegralDict <- integralDict t -> P.foldl1 f . Vec.toList + + floating :: FloatingType (Vec n t) -> (Vec n t -> t) + floating = \case + SingleFloatingType t -> case t of + VectorFloatingType _ t | FloatingDict <- floatingDict t -> P.foldl1 f . Vec.toList + +-- Operators from Integral +-- ----------------------- + +quot :: IntegralType a -> ((a, a) -> a) +quot = int2 P.quot + +rem :: IntegralType a -> ((a, a) -> a) +rem = int2 P.rem + +quotRem :: IntegralType a -> ((a, a) -> (a, a)) +quotRem = int2' P.quotRem + +div :: IntegralType a -> ((a, a) -> a) +div = int2 P.div + +mod :: IntegralType a -> ((a, a) -> a) +mod = int2 P.mod + +divMod :: IntegralType a -> ((a, a) -> (a, a)) +divMod = int2' P.divMod + +int2 :: (forall t. P.Integral t => t -> t -> t) -> IntegralType a -> ((a, a) -> a) +int2 f (SingleIntegralType t) | IntegralDict <- integralDict t = P.uncurry f +int2 f (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.uncurry (zipWith f) + +int2' :: (forall t. P.Integral t => t -> t -> (t, t)) -> IntegralType a -> ((a, a) -> (a, a)) +int2' f (SingleIntegralType t) | IntegralDict <- integralDict t = P.uncurry f +int2' f (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.uncurry (zipWith' f) + + +-- Operators from Bits & FiniteBits +-- -------------------------------- + +band :: IntegralType a -> ((a, a) -> a) +band = bits2 (.&.) + +bor :: IntegralType a -> ((a, a) -> a) +bor = bits2 (.|.) + +xor :: IntegralType a -> ((a, a) -> a) +xor = bits2 P.xor + +complement :: IntegralType a -> (a -> a) +complement = bits1 P.complement + +shiftL :: IntegralType a -> ((a, a) -> a) +shiftL = bits2 (\x i -> x `P.shiftL` P.fromIntegral i) + +shiftR :: IntegralType a -> ((a, a) -> a) +shiftR = bits2 (\x i -> x `P.shiftR` P.fromIntegral i) + +rotateL :: IntegralType a -> ((a, a) -> a) +rotateL = bits2 (\x i -> x `P.rotateL` P.fromIntegral i) + +rotateR :: IntegralType a -> ((a, a) -> a) +rotateR = bits2 (\x i -> x `P.rotateR` P.fromIntegral i) + +popCount :: IntegralType a -> (a -> a) +popCount = bits1 (P.fromIntegral . P.popCount) + +countLeadingZeros :: IntegralType a -> (a -> a) +countLeadingZeros = bits1 (P.fromIntegral . P.countLeadingZeros) + +countTrailingZeros :: IntegralType a -> (a -> a) +countTrailingZeros = bits1 (P.fromIntegral . P.countTrailingZeros) + +bitreverse :: IntegralType a -> (a -> a) +bitreverse = \case + VectorIntegralType _ t | IntegralDict <- integralDict t -> Vec.fromList . P.map (bitreverse (SingleIntegralType t)) . Vec.toList + SingleIntegralType t -> single t + where + single :: SingleIntegralType t -> t -> t + single TypeInt8 (I8# i#) = I8# (word8ToInt8# (wordToWord8# (bitReverse8# (word8ToWord# (int8ToWord8# i#))))) + single TypeInt16 (I16# i#) = I16# (word16ToInt16# (wordToWord16# (bitReverse16# (word16ToWord# (int16ToWord16# i#))))) + single TypeInt32 (I32# i#) = I32# (word32ToInt32# (wordToWord32# (bitReverse32# (word32ToWord# (int32ToWord32# i#))))) + single TypeInt64 (I64# i#) = I64# (word64ToInt64# (bitReverse64# (int64ToWord64# i#))) + single TypeInt128 (Int128 h l) = Int128 (bitreverse integralType l) (bitreverse integralType h) + single TypeWord8 (W8# w#) = W8# (wordToWord8# (bitReverse8# (word8ToWord# w#))) + single TypeWord16 (W16# w#) = W16# (wordToWord16# (bitReverse16# (word16ToWord# w#))) + single TypeWord32 (W32# w#) = W32# (wordToWord32# (bitReverse32# (word32ToWord# w#))) + single TypeWord64 (W64# w#) = W64# (bitReverse64# w#) + single TypeWord128 (Word128 h l) = Word128 (bitreverse integralType l) (bitreverse integralType h) + +byteswap :: IntegralType a -> (a -> a) +byteswap = \case + VectorIntegralType _ t | IntegralDict <- integralDict t -> Vec.fromList . P.map (byteswap (SingleIntegralType t)) . Vec.toList + SingleIntegralType t -> single t + where + single :: SingleIntegralType t -> t -> t + single TypeInt8 x = x + single TypeInt16 (I16# i#) = I16# (word16ToInt16# (wordToWord16# (byteSwap16# (word16ToWord# (int16ToWord16# i#))))) + single TypeInt32 (I32# i#) = I32# (word32ToInt32# (wordToWord32# (byteSwap32# (word32ToWord# (int32ToWord32# i#))))) + single TypeInt64 (I64# i#) = I64# (word64ToInt64# (byteSwap64# (int64ToWord64# i#))) + single TypeInt128 (Int128 h l) = Int128 (byteswap integralType l) (byteswap integralType h) + single TypeWord8 x = x + single TypeWord16 (W16# w#) = W16# (wordToWord16# (byteSwap16# (word16ToWord# w#))) + single TypeWord32 (W32# w#) = W32# (wordToWord32# (byteSwap32# (word32ToWord# w#))) + single TypeWord64 (W64# w#) = W64# (byteSwap64# w#) + single TypeWord128 (Word128 h l) = Word128 (byteswap integralType l) (byteswap integralType h) + +vband :: IntegralType (Vec n a) -> (Vec n a -> a) +vband = vbits2 (.&.) + +vbor :: IntegralType (Vec n a) -> (Vec n a -> a) +vbor = vbits2 (.|.) + +vbxor :: IntegralType (Vec n a) -> (Vec n a -> a) +vbxor = vbits2 P.xor + +bits1 :: (forall t. (P.Integral t, P.FiniteBits t) => t -> t) -> IntegralType a -> (a -> a) +bits1 f (SingleIntegralType t) | IntegralDict <- integralDict t = f +bits1 f (VectorIntegralType _ t) | IntegralDict <- integralDict t = map f + +bits2 :: (forall t. (P.Integral t, P.FiniteBits t) => t -> t -> t) -> IntegralType a -> ((a, a) -> a) +bits2 f (SingleIntegralType t) | IntegralDict <- integralDict t = P.uncurry f +bits2 f (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.uncurry (zipWith f) + +vbits2 :: (forall t. P.FiniteBits t => t -> t -> t) -> IntegralType (Vec n a) -> (Vec n a -> a) +vbits2 _ (SingleIntegralType t) = case t of +vbits2 f (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.foldl1 f . Vec.toList + + +-- Operators from Fractional and Floating +-- -------------------------------------- + +fdiv :: FloatingType a -> ((a, a) -> a) +fdiv = float2 (P./) + +recip :: FloatingType a -> (a -> a) +recip = float1 P.recip + +sin :: FloatingType a -> (a -> a) +sin = float1 P.sin + +cos :: FloatingType a -> (a -> a) +cos = float1 P.cos + +tan :: FloatingType a -> (a -> a) +tan = float1 P.tan + +asin :: FloatingType a -> (a -> a) +asin = float1 P.asin + +acos :: FloatingType a -> (a -> a) +acos = float1 P.acos + +atan :: FloatingType a -> (a -> a) +atan = float1 P.atan + +sinh :: FloatingType a -> (a -> a) +sinh = float1 P.sinh + +cosh :: FloatingType a -> (a -> a) +cosh = float1 P.cosh + +tanh :: FloatingType a -> (a -> a) +tanh = float1 P.tanh + +asinh :: FloatingType a -> (a -> a) +asinh = float1 P.asinh + +acosh :: FloatingType a -> (a -> a) +acosh = float1 P.acosh + +atanh :: FloatingType a -> (a -> a) +atanh = float1 P.atanh + +exp :: FloatingType a -> (a -> a) +exp = float1 P.exp + +sqrt :: FloatingType a -> (a -> a) +sqrt = float1 P.sqrt + +log :: FloatingType a -> (a -> a) +log = float1 P.log + +pow :: FloatingType a -> ((a, a) -> a) +pow = float2 (P.**) + +logBase :: FloatingType a -> ((a, a) -> a) +logBase = float2 P.logBase + + +float1 :: (forall t. P.RealFloat t => t -> t) -> FloatingType a -> (a -> a) +float1 f (SingleFloatingType t) | FloatingDict <- floatingDict t = f +float1 f (VectorFloatingType _ t) | FloatingDict <- floatingDict t = map f + +float2 :: (forall t. P.RealFloat t => t -> t -> t) -> FloatingType a -> ((a, a) -> a) +float2 f (SingleFloatingType t) | FloatingDict <- floatingDict t = P.uncurry f +float2 f (VectorFloatingType _ t) | FloatingDict <- floatingDict t = P.uncurry (zipWith f) + + +-- Operators from RealFrac +-- ----------------------- + +truncate :: FloatingType a -> IntegralType b -> (a -> b) +truncate (SingleFloatingType a) (SingleIntegralType b) + | FloatingDict <- floatingDict a + , IntegralDict <- integralDict b + = P.truncate +truncate (VectorFloatingType n a) (VectorIntegralType m b) + | Just Refl <- sameNat' n m + , FloatingDict <- floatingDict a + , IntegralDict <- integralDict b + = map P.truncate +truncate a b + = internalError ("truncate: cannot reconcile `" % formatFloatingType % "' with `" % formatIntegralType % "'") a b + +round :: FloatingType a -> IntegralType b -> (a -> b) +round (SingleFloatingType a) (SingleIntegralType b) + | FloatingDict <- floatingDict a + , IntegralDict <- integralDict b + = P.round +round (VectorFloatingType n a) (VectorIntegralType m b) + | Just Refl <- sameNat' n m + , FloatingDict <- floatingDict a + , IntegralDict <- integralDict b + = map P.round +round a b + = internalError ("round: cannot reconcile `" % formatFloatingType % "' with `" % formatIntegralType % "'") a b + +floor :: FloatingType a -> IntegralType b -> (a -> b) +floor (SingleFloatingType a) (SingleIntegralType b) + | FloatingDict <- floatingDict a + , IntegralDict <- integralDict b + = P.floor +floor (VectorFloatingType n a) (VectorIntegralType m b) + | Just Refl <- sameNat' n m + , FloatingDict <- floatingDict a + , IntegralDict <- integralDict b + = map P.floor +floor a b + = internalError ("floor: cannot reconcile `" % formatFloatingType % "' with `" % formatIntegralType % "'") a b + +ceiling :: FloatingType a -> IntegralType b -> (a -> b) +ceiling (SingleFloatingType a) (SingleIntegralType b) + | FloatingDict <- floatingDict a + , IntegralDict <- integralDict b + = P.ceiling +ceiling (VectorFloatingType n a) (VectorIntegralType m b) + | Just Refl <- sameNat' n m + , FloatingDict <- floatingDict a + , IntegralDict <- integralDict b + = map P.ceiling +ceiling a b + = internalError ("ceiling: cannot reconcile `" % formatFloatingType % "' with `" % formatIntegralType % "'") a b + + +-- Operators from RealFloat +-- ------------------------ + +atan2 :: FloatingType a -> ((a, a) -> a) +atan2 = float2 P.atan2 + +isNaN :: FloatingType a -> (a -> BitOrMask a) +isNaN = \case + SingleFloatingType t | FloatingDict <- floatingDict t -> isNaN' + VectorFloatingType _ t | FloatingDict <- floatingDict t -> unMask . map isNaN' + where + isNaN' x = Bit (P.isNaN x) + +isInfinite :: FloatingType a -> (a -> BitOrMask a) +isInfinite = \case + SingleFloatingType t | FloatingDict <- floatingDict t -> isInfinite' + VectorFloatingType _ t | FloatingDict <- floatingDict t -> unMask . map isInfinite' + where + isInfinite' x = Bit (P.isInfinite x) + + +-- Operators from Eq & Ord +-- ----------------------- + +lt :: ScalarType a -> ((a, a) -> BitOrMask a) +lt = cmp (P.<) + +gt :: ScalarType a -> ((a, a) -> BitOrMask a) +gt = cmp (P.>) + +lte :: ScalarType a -> ((a, a) -> BitOrMask a) +lte = cmp (P.<=) + +gte :: ScalarType a -> ((a, a) -> BitOrMask a) +gte = cmp (P.>=) + +eq :: ScalarType a -> ((a, a) -> BitOrMask a) +eq = cmp (P.==) + +neq :: ScalarType a -> ((a, a) -> BitOrMask a) +neq = cmp (P./=) + +cmp :: (forall t. P.Ord t => (t -> t -> Bool)) -> ScalarType a -> ((a, a) -> BitOrMask a) +cmp f = \case + NumScalarType t -> num t + BitScalarType t -> bit t + where + bit :: BitType t -> ((t, t) -> BitOrMask t) + bit TypeBit = Bit . P.uncurry f + bit TypeMask{} = \(x,y) -> unMask (zipWith (Bit $$ f) (BitMask x) (BitMask y)) + + num :: NumType t -> ((t, t) -> BitOrMask t) + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType t -> ((t, t) -> BitOrMask t) + integral (SingleIntegralType t) | IntegralDict <- integralDict t = P.uncurry (Bit $$ f) + integral (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.uncurry (unMask $$ zipWith (Bit $$ f)) + + floating :: FloatingType t -> ((t, t) -> BitOrMask t) + floating (SingleFloatingType t) | FloatingDict <- floatingDict t = P.uncurry (Bit $$ f) + floating (VectorFloatingType _ t) | FloatingDict <- floatingDict t = P.uncurry (unMask $$ zipWith (Bit $$ f)) + +min :: ScalarType a -> ((a, a) -> a) +min = ord2 P.min + +max :: ScalarType a -> ((a, a) -> a) +max = ord2 P.max + +vmin :: ScalarType (Vec n a) -> (Vec n a -> a) +vmin = vord2 P.min + +vmax :: ScalarType (Vec n a) -> (Vec n a -> a) +vmax = vord2 P.max + +ord2 :: (forall t. P.Ord t => t -> t -> t) -> ScalarType a -> ((a, a) -> a) +ord2 f = \case + NumScalarType t -> num t + BitScalarType t -> bit t + where + bit :: BitType t -> ((t, t) -> t) + bit TypeBit = P.uncurry f + bit TypeMask{} = \(x,y) -> unMask (zipWith f (BitMask x) (BitMask y)) + + num :: NumType t -> ((t, t) -> t) + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType t -> ((t, t) -> t) + integral (SingleIntegralType t) | IntegralDict <- integralDict t = P.uncurry f + integral (VectorIntegralType _ t) | IntegralDict <- integralDict t = P.uncurry (zipWith f) + + floating :: FloatingType t -> ((t, t) -> t) + floating (SingleFloatingType t) | FloatingDict <- floatingDict t = P.uncurry f + floating (VectorFloatingType _ t) | FloatingDict <- floatingDict t = P.uncurry (zipWith f) + + +vord2 :: (forall t. P.Ord t => t -> t -> t) -> ScalarType (Vec n a) -> (Vec n a -> a) +vord2 f = \case + NumScalarType t -> num t + BitScalarType t -> bit t + where + bit :: BitType (Vec n t) -> (Vec n t -> t) + bit TypeMask{} = P.foldl1 f . Bit.toList . BitMask + + num :: NumType (Vec n t) -> (Vec n t -> t) + num = \case + IntegralNumType t -> integral t + FloatingNumType t -> floating t + + integral :: IntegralType (Vec n t) -> (Vec n t -> t) + integral = \case + SingleIntegralType t -> case t of + VectorIntegralType _ t | IntegralDict <- integralDict t -> P.foldl1 f . Vec.toList + + floating :: FloatingType (Vec n t) -> (Vec n t -> t) + floating = \case + SingleFloatingType t -> case t of + VectorFloatingType _ t | FloatingDict <- floatingDict t -> P.foldl1 f . Vec.toList + +-- Logical operators +-- ----------------- + +land :: BitType a -> ((a, a) -> a) +land = \case + TypeBit -> P.uncurry land' + TypeMask{} -> \(x,y) -> unMask (zipWith land' (BitMask x) (BitMask y)) + where + land' (Bit x) (Bit y) = Bit (x && y) + +lor :: BitType a -> ((a, a) -> a) +lor = \case + TypeBit -> P.uncurry lor' + TypeMask{} -> \(x,y) -> unMask (zipWith lor' (BitMask x) (BitMask y)) + where + lor' (Bit x) (Bit y) = Bit (x || y) + +lnot :: BitType a -> (a -> a) +lnot = \case + TypeBit -> not' + TypeMask{} -> unMask . map not' . BitMask + where + not' (Bit x) = Bit (not x) + +vland :: BitType (Vec n a) -> (Vec n a -> a) +vland TypeMask{} x = Bit $ P.all unBit (toList (BitMask x)) + +vlor :: BitType (Vec n a) -> (Vec n a -> a) +vlor TypeMask{} x = Bit $ P.any unBit (toList (BitMask x)) + + +-- Conversion +-- ---------- + +fromIntegral :: forall a b. IntegralType a -> NumType b -> (a -> b) +fromIntegral (SingleIntegralType a) (IntegralNumType (SingleIntegralType b)) + | IntegralDict <- integralDict a + , IntegralDict <- integralDict b + = P.fromIntegral +fromIntegral (SingleIntegralType a) (FloatingNumType (SingleFloatingType b)) + | IntegralDict <- integralDict a + , FloatingDict <- floatingDict b + = P.fromIntegral +fromIntegral (VectorIntegralType n a) (IntegralNumType (VectorIntegralType m b)) + | Just Refl <- sameNat' n m + , IntegralDict <- integralDict a + , IntegralDict <- integralDict b + = map P.fromIntegral +fromIntegral (VectorIntegralType n a) (FloatingNumType (VectorFloatingType m b)) + | Just Refl <- sameNat' n m + , IntegralDict <- integralDict a + , FloatingDict <- floatingDict b + = map P.fromIntegral +fromIntegral a b + = internalError ("fromIntegral: cannot reconcile `" % formatIntegralType % "' with `" % formatNumType % "'") a b + + +toFloating :: forall a b. NumType a -> FloatingType b -> (a -> b) +toFloating (IntegralNumType (SingleIntegralType a)) (SingleFloatingType b) + | IntegralDict <- integralDict a + , FloatingDict <- floatingDict b + = P.realToFrac +toFloating (FloatingNumType (SingleFloatingType a)) (SingleFloatingType b) + | FloatingDict <- floatingDict a + , FloatingDict <- floatingDict b + = P.realToFrac +toFloating (IntegralNumType (VectorIntegralType n a)) (VectorFloatingType m b) + | Just Refl <- sameNat' n m + , IntegralDict <- integralDict a + , FloatingDict <- floatingDict b + = map P.realToFrac +toFloating (FloatingNumType (VectorFloatingType n a)) (VectorFloatingType m b) + | Just Refl <- sameNat' n m + , FloatingDict <- floatingDict a + , FloatingDict <- floatingDict b + = map P.realToFrac +toFloating a b + = internalError ("toFloating: cannot reconcile `" % formatNumType % "' with `" % formatFloatingType % "'") a b + + +toBool :: IntegralType a -> BitType b -> (a -> b) +toBool iR bR = + case iR of + SingleIntegralType t | IntegralDict <- integralDict t -> + case bR of + TypeBit -> P.fromIntegral + TypeMask n -> \x -> let m = P.finiteBitSize x + bits = P.map (Bit . P.testBit x) [0 .. m P.- 1] P.++ P.repeat (Bit False) + in + unMask $ fromList (P.reverse (P.take (P.fromIntegral (natVal' n)) bits)) + -- + VectorIntegralType _ t | IntegralDict <- integralDict t -> + case bR of + TypeBit -> \x -> P.fromIntegral (Vec.extract x 0) -- XXX: first or last lane? + TypeMask _ -> unMask . map P.fromIntegral + +fromBool :: forall a b. BitType a -> IntegralType b -> (a -> b) +fromBool bR iR = + case bR of + TypeBit -> + case iR of + SingleIntegralType t | IntegralDict <- integralDict t -> P.fromIntegral + VectorIntegralType _ t | IntegralDict <- integralDict t -> \x -> + if unBit x + then Vec.insert (Vec.splat 0) 0 1 -- XXX: first or last lane? + else Vec.splat 0 + -- + TypeMask _ -> + case iR of + SingleIntegralType t | IntegralDict <- integralDict t -> \x -> + let bits = toList (BitMask x) + m = P.finiteBitSize (P.undefined :: b) + -- + go !_ !w [] = w + go !i !w (Bit True : bs) = go (i P.+ 1) (P.setBit w i) bs + go !i !w (Bit False: bs) = go (i P.+ 1) w bs + in + go 0 0 (P.reverse (P.take m bits)) + -- + VectorIntegralType _ t | IntegralDict <- integralDict t -> map P.fromIntegral . BitMask + + +-- Vector element-wise operations +-- ------------------------------ + +-- XXX: These implementations lose the type safety that the length of the +-- underlying Vec or BitMask is preserved + +map :: (IsList a, IsList b) => (Item a -> Item b) -> a -> b +map f xs = fromList $ P.map f (toList xs) + +zipWith :: (IsList a, IsList b, IsList c) => (Item a -> Item b -> Item c) -> a -> b -> c +zipWith f xs ys = fromList $ P.zipWith f (toList xs) (toList ys) + +zipWith' + :: (IsList a, IsList b, IsList c, IsList d) + => (Item a -> Item b -> (Item c, Item d)) + -> a + -> b + -> (c, d) +zipWith' f xs ys = + let (us, vs) = P.unzip $ P.zipWith f (toList xs) (toList ys) + in (fromList us, fromList vs) + + +-- Utilities +-- --------- + +data IntegralDict t where + IntegralDict :: (P.Integral t, P.FiniteBits t, Prim t, BitOrMask t ~ Bit) => IntegralDict t + +data FloatingDict t where + FloatingDict :: (P.RealFloat t, Prim t, BitOrMask t ~ Bit) => FloatingDict t + +{-# INLINE integralDict #-} +integralDict :: SingleIntegralType t -> IntegralDict t +integralDict TypeInt8 = IntegralDict +integralDict TypeInt16 = IntegralDict +integralDict TypeInt32 = IntegralDict +integralDict TypeInt64 = IntegralDict +integralDict TypeInt128 = IntegralDict +integralDict TypeWord8 = IntegralDict +integralDict TypeWord16 = IntegralDict +integralDict TypeWord32 = IntegralDict +integralDict TypeWord64 = IntegralDict +integralDict TypeWord128 = IntegralDict + +{-# INLINE floatingDict #-} +floatingDict :: SingleFloatingType t -> FloatingDict t +floatingDict TypeFloat16 = FloatingDict +floatingDict TypeFloat32 = FloatingDict +floatingDict TypeFloat64 = FloatingDict +floatingDict TypeFloat128 = FloatingDict + +infixr 0 $$ +($$) :: (b -> a) -> (c -> d -> b) -> c -> d -> a +(f $$ g) x y = f (g x y) + +#if __GLASGOW_HASKELL__ < 904 +int64ToWord64# :: Int# -> Word# +int64ToWord64# = int2Word# + +word64ToInt64# :: Word# -> Int# +word64ToInt64# = word2Int# +#endif + +#if __GLASGOW_HASKELL__ < 902 +wordToWord8# :: Word# -> Word# +wordToWord8# x = x + +wordToWord16# :: Word# -> Word# +wordToWord16# x = x + +wordToWord32# :: Word# -> Word# +wordToWord32# x = x + +word8ToWord# :: Word# -> Word# +word8ToWord# x = x + +word16ToWord# :: Word# -> Word# +word16ToWord# x = x + +word32ToWord# :: Word# -> Word# +word32ToWord# x = x + +int8ToWord8# :: Int# -> Word# +int8ToWord8# = int2Word# + +int16ToWord16# :: Int# -> Word# +int16ToWord16# = int2Word# + +int32ToWord32# :: Int# -> Word# +int32ToWord32# = int2Word# + +word8ToInt8# :: Word# -> Int# +word8ToInt8# = word2Int# + +word16ToInt16# :: Word# -> Int# +word16ToInt16# = word2Int# + +word32ToInt32# :: Word# -> Int# +word32ToInt32# = word2Int# +#endif + diff --git a/src/Data/Array/Accelerate/Language.hs b/src/Data/Array/Accelerate/Language.hs index d32cd94a5..345ab23d3 100644 --- a/src/Data/Array/Accelerate/Language.hs +++ b/src/Data/Array/Accelerate/Language.hs @@ -17,13 +17,6 @@ -- Stability : experimental -- Portability : non-portable (GHC extensions) -- --- We use the dictionary view of overloaded operations (such as arithmetic and --- bit manipulation) to reify such expressions. With non-overloaded --- operations (such as, the logical connectives) and partially overloaded --- operations (such as comparisons), we use the standard operator names with a --- \'*\' attached. We keep the standard alphanumeric names as they can be --- easily qualified. --- module Data.Array.Accelerate.Language ( @@ -39,18 +32,6 @@ module Data.Array.Accelerate.Language ( -- * Map-like functions map, zipWith, - -- -- * Sequence collection - -- collect, - - -- -- * Sequence producers - -- streamIn, toSeq, - - -- -- * Sequence transducers - -- mapSeq, zipWithSeq, scanSeq, - - -- -- * Sequence consumers - -- foldSeq, foldSeqFlatten, - -- * Reductions fold, fold1, foldSeg', fold1Seg', @@ -67,7 +48,6 @@ module Data.Array.Accelerate.Language ( Boundary, Stencil, clamp, mirror, wrap, function, - -- ** Common stencil types Stencil3, Stencil5, Stencil7, Stencil9, Stencil3x3, Stencil5x3, Stencil3x5, Stencil5x5, @@ -96,12 +76,13 @@ module Data.Array.Accelerate.Language ( subtract, even, odd, gcd, lcm, (^), (^^), -- * Conversions - ord, chr, boolToInt, bitcast, + ord, chr, bitcast, ) where import Data.Array.Accelerate.AST ( PrimFun(..) ) -import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.Maybe +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Representation.Array ( ArrayR(..) ) import Data.Array.Accelerate.Representation.Shape ( ShapeR(..) ) import Data.Array.Accelerate.Representation.Type @@ -119,8 +100,8 @@ import Data.Array.Accelerate.Classes.Integral import Data.Array.Accelerate.Classes.Num import Data.Array.Accelerate.Classes.Ord -import Prelude ( ($), (.), Maybe(..), Char ) -#if __GLASGOW_HASKELL__ >= 904 +import Prelude ( ($), (.), Char ) +#if __GLASGOW_HASKELL__ == 904 import Data.Type.Equality #endif @@ -193,7 +174,7 @@ unit (Exp e) = Acc $ SmartAcc $ Unit (eltR @e) e -- ...we can replicate these elements to form a two-dimensional array either by -- replicating those elements as new rows: -- --- >>> run $ replicate (constant (Z :. (4::Int) :. All)) (use vec) +-- >>> run $ replicate @(Z :. Int :. All) (Z :. 4 :. All) (use vec) -- Matrix (Z :. 4 :. 10) -- [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, @@ -202,7 +183,7 @@ unit (Exp e) = Acc $ SmartAcc $ Unit (eltR @e) e -- -- ...or as columns: -- --- >>> run $ replicate (lift (Z :. All :. (4::Int))) (use vec) +-- >>> run $ replicate @(Z :. All :. Int) (Z :. All :. 4) (use vec) -- Matrix (Z :. 10 :. 4) -- [ 0, 0, 0, 0, -- 1, 1, 1, 1, @@ -218,7 +199,7 @@ unit (Exp e) = Acc $ SmartAcc $ Unit (eltR @e) e -- Replication along more than one dimension is also possible. Here we replicate -- twice across the first dimension and three times across the third dimension: -- --- >>> run $ replicate (constant (Z :. (2::Int) :. All :. (3::Int))) (use vec) +-- >>> run $ replicate @(Z :. Int :. All :. Int) (Z :. 2 :. All :. 3) (use vec) -- Array (Z :. 2 :. 10 :. 3) [0,0,0,1,1,1,2,2,2,3,3,3,4,4,4,5,5,5,6,6,6,7,7,7,8,8,8,9,9,9,0,0,0,1,1,1,2,2,2,3,3,3,4,4,4,5,5,5,6,6,6,7,7,7,8,8,8,9,9,9] -- -- The marker 'Any' can be used in the slice specification to match against some @@ -227,7 +208,7 @@ unit (Exp e) = Acc $ SmartAcc $ Unit (eltR @e) e -- -- >>> :{ -- let rep0 :: forall sh e. (Shape sh, Elt e) => Exp Int -> Acc (Array sh e) -> Acc (Array (sh :. Int) e) --- rep0 n a = replicate (lift (Any @sh :. n)) a +-- rep0 n a = replicate @(Any sh :. Int) (Any :. n) a -- :} -- -- >>> let x = unit 42 :: Acc (Scalar Int) @@ -251,7 +232,7 @@ unit (Exp e) = Acc $ SmartAcc $ Unit (eltR @e) e -- -- >>> :{ -- let rep1 :: forall sh e. (Shape sh, Elt e) => Exp Int -> Acc (Array (sh :. Int) e) -> Acc (Array (sh :. Int :. Int) e) --- rep1 n a = replicate (lift (Any @sh :. n :. All)) a +-- rep1 n a = replicate @(Any sh :. Int :. All) (Any :. n :. All) a -- :} -- -- >>> run $ rep1 5 (use vec) @@ -349,13 +330,13 @@ reshape = Acc $$ applyAcc (Reshape $ shapeR @sh) -- ...will can select a specific row to yield a one dimensional result by fixing -- the row index (2) while allowing the column index to vary (via 'All'): -- --- >>> run $ slice (use mat) (constant (Z :. (2::Int) :. All)) +-- >>> run $ slice @(Z :. Int :. All) (use mat) (Z :. 2 :. All) -- Vector (Z :. 10) [20,21,22,23,24,25,26,27,28,29] -- -- A fully specified index (with no 'All's) returns a single element (zero -- dimensional array). -- --- >>> run $ slice (use mat) (constant (Z :. 4 :. 2 :: DIM2)) +-- >>> run $ slice @DIM2 (use mat) (Z :. 4 :. 2) -- Scalar Z [42] -- -- The marker 'Any' can be used in the slice specification to match against some @@ -365,7 +346,7 @@ reshape = Acc $$ applyAcc (Reshape $ shapeR @sh) -- >>> :{ -- let -- sl0 :: forall sh e. (Shape sh, Elt e) => Acc (Array (sh:.Int) e) -> Exp Int -> Acc (Array sh e) --- sl0 a n = slice a (lift (Any @sh :. n)) +-- sl0 a n = slice @(Any sh :. Int) a (Any :. n) -- :} -- -- >>> let vec = fromList (Z:.10) [0..] :: Vector Int @@ -379,7 +360,7 @@ reshape = Acc $$ applyAcc (Reshape $ shapeR @sh) -- -- >>> :{ -- let sl1 :: forall sh e. (Shape sh, Elt e) => Acc (Array (sh:.Int:.Int) e) -> Exp Int -> Acc (Array (sh:.Int) e) --- sl1 a n = slice a (lift (Any @sh :. n :. All)) +-- sl1 a n = slice @(Any sh :. Int :. All) a (Any :. n :. All) -- :} -- -- >>> run $ sl1 (use mat) 4 @@ -520,8 +501,7 @@ zipWith = Acc $$$ applyAcc (ZipWith (eltR @a) (eltR @b) (eltR @c)) -- See also 'Data.Array.Accelerate.Data.Fold.Fold', which can be a useful way to -- compute multiple results from a single reduction. -- -fold :: forall sh a. - (Shape sh, Elt a) +fold :: forall sh a. (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Exp a -> Acc (Array (sh:.Int) a) @@ -538,8 +518,7 @@ fold f (Exp x) = Acc . applyAcc (Fold (eltR @a) (unExpBinaryFunction f) (Just x) -- The first argument needs to be an /associative/ function to enable an -- efficient parallel implementation, but does not need to be commutative. -- -fold1 :: forall sh a. - (Shape sh, Elt a) +fold1 :: forall sh a. (Shape sh, Elt a) => (Exp a -> Exp a -> Exp a) -> Acc (Array (sh:.Int) a) -> Acc (Array sh a) @@ -559,14 +538,13 @@ fold1 f = Acc . applyAcc (Fold (eltR @a) (unExpBinaryFunction f) Nothing) -- @since 1.3.0.0 -- foldSeg' - :: forall sh a i. - (Shape sh, Elt a, Elt i, IsIntegral i, i ~ EltR i) + :: forall sh a i. (Shape sh, Elt a, Elt i, IsSingleIntegral (EltR i)) => (Exp a -> Exp a -> Exp a) -> Exp a -> Acc (Array (sh:.Int) a) -> Acc (Segments i) -> Acc (Array (sh:.Int) a) -foldSeg' f (Exp x) = Acc $$ applyAcc (FoldSeg (integralType @i) (eltR @a) (unExpBinaryFunction f) (Just x)) +foldSeg' f (Exp x) = Acc $$ applyAcc (FoldSeg (singleIntegralType @(EltR i)) (eltR @a) (unExpBinaryFunction f) (Just x)) -- | Variant of 'foldSeg'' that requires /all/ segments of the reduced -- array to be non-empty, and doesn't need a default value. The segment @@ -576,13 +554,12 @@ foldSeg' f (Exp x) = Acc $$ applyAcc (FoldSeg (integralType @i) (eltR @a) (unExp -- @since 1.3.0.0 -- fold1Seg' - :: forall sh a i. - (Shape sh, Elt a, Elt i, IsIntegral i, i ~ EltR i) + :: forall sh a i. (Shape sh, Elt a, Elt i, IsSingleIntegral (EltR i)) => (Exp a -> Exp a -> Exp a) -> Acc (Array (sh:.Int) a) -> Acc (Segments i) -> Acc (Array (sh:.Int) a) -fold1Seg' f = Acc $$ applyAcc (FoldSeg (integralType @i) (eltR @a) (unExpBinaryFunction f) Nothing) +fold1Seg' f = Acc $$ applyAcc (FoldSeg (singleIntegralType @(EltR i)) (eltR @a) (unExpBinaryFunction f) Nothing) -- Scan functions -- -------------- @@ -720,7 +697,7 @@ scanr1 f (Acc a) = Acc $ SmartAcc $ Scan RightToLeft (eltR @a) (unExpBinaryFunct -- let zeros = fill (constant (Z:.10)) 0 -- ones = fill (shape xs) 1 -- in --- permute (+) zeros (\ix -> Just_ (I1 (xs!ix))) ones +-- permute (+) zeros (\ix -> Just (I1 (xs!ix))) ones -- :} -- -- >>> let xs = fromList (Z :. 20) [0,0,1,2,1,1,2,4,8,3,4,9,8,3,2,5,5,3,1,2] :: Vector Int @@ -737,7 +714,7 @@ scanr1 f (Acc a) = Acc $ SmartAcc $ Scan RightToLeft (eltR @a) (unExpBinaryFunct -- let zeros = fill (I2 n n) 0 -- ones = fill (I1 n) 1 -- in --- permute const zeros (\(I1 i) -> Just_ (I2 i i)) ones +-- permute const zeros (\(I1 i) -> Just (I2 i i)) ones -- :} -- -- >>> run $ identity 5 :: Matrix Int @@ -1267,11 +1244,7 @@ awhile :: forall a. Arrays a -> (Acc a -> Acc a) -- ^ function to apply -> Acc a -- ^ initial value -> Acc a -awhile f = Acc $$ applyAcc $ Awhile (arraysR @a) (unAccFunction g) - where - -- FIXME: This should be a no-op! - g :: Acc a -> Acc (Scalar PrimBool) - g = map mkCoerce . f +awhile f = Acc $$ applyAcc $ Awhile (arraysR @a) (unAccFunction f) -- Shapes and indices @@ -1302,7 +1275,7 @@ intersect (Exp shx) (Exp shy) = Exp $ intersect' (shapeR @sh) shx shy intersect' (ShapeRsnoc shR) (unPair -> (xs, x)) (unPair -> (ys, y)) = SmartExp $ intersect' shR xs ys `Pair` - SmartExp (PrimApp (PrimMin singleType) $ SmartExp $ Pair x y) + SmartExp (PrimApp (PrimMin scalarType) $ SmartExp $ Pair x y) -- | Union of two shapes @@ -1315,7 +1288,7 @@ union (Exp shx) (Exp shy) = Exp $ union' (shapeR @sh) shx shy union' (ShapeRsnoc shR) (unPair -> (xs, x)) (unPair -> (ys, y)) = SmartExp $ union' shR xs ys `Pair` - SmartExp (PrimApp (PrimMax singleType) $ SmartExp $ Pair x y) + SmartExp (PrimApp (PrimMax scalarType) $ SmartExp $ Pair x y) -- Flow-control @@ -1344,7 +1317,8 @@ while :: forall e. Elt e while c f (Exp e) = mkExp $ While @(EltR e) (eltR @e) (mkCoerce' . unExp . c . Exp) - (unExp . f . Exp) e + (unExp . f . Exp) + e -- Array operations with a scalar result @@ -1493,18 +1467,12 @@ x ^^ n -- |Convert a character to an 'Int'. -- ord :: Exp Char -> Exp Int -ord = mkFromIntegral +ord = mkPrimUnary $ PrimFromIntegral integralType numType -- |Convert an 'Int' into a character. -- chr :: Exp Int -> Exp Char -chr = mkFromIntegral - --- |Convert a Boolean value to an 'Int', where 'False' turns into '0' and 'True' --- into '1'. --- -boolToInt :: Exp Bool -> Exp Int -boolToInt = mkFromIntegral . mkCoerce @_ @Word8 +chr = mkPrimUnary $ PrimFromIntegral integralType numType -- |Reinterpret a value as another type. The two representations must have the -- same bit size. diff --git a/src/Data/Array/Accelerate/Lift.hs b/src/Data/Array/Accelerate/Lift.hs index 27482a86f..539dd2823 100644 --- a/src/Data/Array/Accelerate/Lift.hs +++ b/src/Data/Array/Accelerate/Lift.hs @@ -35,14 +35,15 @@ module Data.Array.Accelerate.Lift ( ) where import Data.Array.Accelerate.AST.Idx -import Data.Array.Accelerate.Pattern +import Data.Array.Accelerate.Pattern.Shape import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Array import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Sugar.Shape +import Data.Array.Accelerate.Sugar.Shape ( Shape, DIM1 ) import Data.Array.Accelerate.Type import Language.Haskell.TH.Extra hiding ( Exp ) +import Foreign.C.Types -- | Lift a unary function into 'Exp'. @@ -75,17 +76,17 @@ lift3 f x y z = lift $ f (unlift x) (unlift y) (unlift z) -- | Lift a unary function to a computation over rank-1 indices. -- ilift1 :: (Exp Int -> Exp Int) -> Exp DIM1 -> Exp DIM1 -ilift1 f = lift1 (\(Z:.i) -> Z :. f i) +ilift1 f (Z:.i) = Z :. f i -- | Lift a binary function to a computation over rank-1 indices. -- ilift2 :: (Exp Int -> Exp Int -> Exp Int) -> Exp DIM1 -> Exp DIM1 -> Exp DIM1 -ilift2 f = lift2 (\(Z:.i) (Z:.j) -> Z :. f i j) +ilift2 f (Z:.i) (Z:.j) = Z :. f i j -- | Lift a ternary function to a computation over rank-1 indices. -- ilift3 :: (Exp Int -> Exp Int -> Exp Int -> Exp Int) -> Exp DIM1 -> Exp DIM1 -> Exp DIM1 -> Exp DIM1 -ilift3 f = lift3 (\(Z:.i) (Z:.j) (Z:.k) -> Z :. f i j k) +ilift3 f (Z :. i) (Z :. j) (Z :. k) = Z :. f i j k -- | The class of types @e@ which can be lifted into @c@. @@ -148,28 +149,28 @@ instance Unlift Acc (Acc a) where instance Lift Exp Z where type Plain Z = Z - lift _ = Z_ + lift _ = Z instance Unlift Exp Z where unlift _ = Z instance (Elt (Plain ix), Lift Exp ix) => Lift Exp (ix :. Int) where type Plain (ix :. Int) = Plain ix :. Int - lift (ix :. i) = lift ix ::. lift i + lift (ix :. i) = lift ix :. lift i instance (Elt (Plain ix), Lift Exp ix) => Lift Exp (ix :. All) where type Plain (ix :. All) = Plain ix :. All - lift (ix :. i) = lift ix ::. constant i + lift (ix :. i) = lift ix :. constant i instance (Elt e, Elt (Plain ix), Lift Exp ix) => Lift Exp (ix :. Exp e) where type Plain (ix :. Exp e) = Plain ix :. e - lift (ix :. i) = lift ix ::. i + lift (ix :. i) = lift ix :. i instance {-# OVERLAPPABLE #-} (Elt e, Elt (Plain ix), Unlift Exp ix) => Unlift Exp (ix :. Exp e) where - unlift (ix ::. i) = unlift ix :. i + unlift (ix :. i) = unlift ix :. i instance {-# OVERLAPPABLE #-} (Elt e, Elt ix) => Unlift Exp (Exp ix :. Exp e) where - unlift (ix ::. i) = ix :. i + unlift (ix :. i) = ix :. i instance (Shape sh, Elt (Any sh)) => Lift Exp (Any sh) where type Plain (Any sh) = Any sh @@ -276,8 +277,9 @@ instance Lift Exp CDouble where instance Lift Exp Bool where type Plain Bool = Bool - lift True = Exp . SmartExp $ SmartExp (Const scalarType 1) `Pair` SmartExp Nil - lift False = Exp . SmartExp $ SmartExp (Const scalarType 0) `Pair` SmartExp Nil + lift = Exp . SmartExp . Const scalarType . Bit + -- lift True = Exp . SmartExp $ SmartExp (Const scalarType 1) `Pair` SmartExp Nil + -- lift False = Exp . SmartExp $ SmartExp (Const scalarType 0) `Pair` SmartExp Nil instance Lift Exp Char where type Plain Char = Char diff --git a/src/Data/Array/Accelerate/Orphans.hs b/src/Data/Array/Accelerate/Orphans.hs index 39a43d493..c2f6b1806 100644 --- a/src/Data/Array/Accelerate/Orphans.hs +++ b/src/Data/Array/Accelerate/Orphans.hs @@ -28,7 +28,6 @@ import Numeric.Half -- base --- deriving instance (Show a, Show b, Show c, Show d, Show e, Show f, Show g, Show h, Show i, Show j, Show k, Show l, Show m, Show n, Show o, Show p) => Show (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) @@ -47,6 +46,5 @@ deriving instance Generic (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) deriving instance Generic (Ratio a) -- primitive --- deriving instance Prim Half diff --git a/src/Data/Array/Accelerate/Pattern.hs b/src/Data/Array/Accelerate/Pattern.hs index e212c0869..3ff082de7 100644 --- a/src/Data/Array/Accelerate/Pattern.hs +++ b/src/Data/Array/Accelerate/Pattern.hs @@ -11,7 +11,6 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} -- | -- Module : Data.Array.Accelerate.Pattern @@ -26,28 +25,14 @@ module Data.Array.Accelerate.Pattern ( pattern Pattern, - pattern T2, pattern T3, pattern T4, pattern T5, pattern T6, - pattern T7, pattern T8, pattern T9, pattern T10, pattern T11, - pattern T12, pattern T13, pattern T14, pattern T15, pattern T16, - - pattern Z_, pattern Ix, pattern (::.), pattern All_, pattern Any_, - pattern I0, pattern I1, pattern I2, pattern I3, pattern I4, - pattern I5, pattern I6, pattern I7, pattern I8, pattern I9, - - pattern V2, pattern V3, pattern V4, pattern V8, pattern V16, ) where import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.Representation.Tag -import Data.Array.Accelerate.Representation.Vec import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Array import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Sugar.Shape -import Data.Array.Accelerate.Sugar.Vec -import Data.Array.Accelerate.Type -import Data.Primitive.Vec import Language.Haskell.TH.Extra hiding ( Exp, Match ) @@ -58,238 +43,91 @@ import Language.Haskell.TH.Extra hiding ( Exp pattern Pattern :: forall b a context. IsPattern context a b => b -> context a pattern Pattern vars <- (matcher @context -> vars) where Pattern = builder @context +{-# COMPLETE Pattern :: Exp #-} +{-# COMPLETE Pattern :: Acc #-} class IsPattern context a b where builder :: b -> context a matcher :: context a -> b -pattern Vector :: forall b a context. IsVector context a b => b -> context a -pattern Vector vars <- (vunpack @context -> vars) - where Vector = vpack @context - -class IsVector context a b where - vpack :: b -> context a - vunpack :: context a -> b - --- | Pattern synonyms for indices, which may be more convenient to use than --- 'Data.Array.Accelerate.Lift.lift' and --- 'Data.Array.Accelerate.Lift.unlift'. --- -pattern Z_ :: Exp DIM0 -pattern Z_ = Pattern Z -{-# COMPLETE Z_ #-} - -infixl 3 ::. -pattern (::.) :: (Elt a, Elt b) => Exp a -> Exp b -> Exp (a :. b) -pattern a ::. b = Pattern (a :. b) -{-# COMPLETE (::.) #-} - -infixl 3 `Ix` -pattern Ix :: (Elt a, Elt b) => Exp a -> Exp b -> Exp (a :. b) -pattern a `Ix` b = a ::. b -{-# COMPLETE Ix #-} - -pattern All_ :: Exp All -pattern All_ <- (const True -> True) - where All_ = constant All -{-# COMPLETE All_ #-} - -pattern Any_ :: (Shape sh, Elt (Any sh)) => Exp (Any sh) -pattern Any_ <- (const True -> True) - where Any_ = constant Any -{-# COMPLETE Any_ #-} - --- IsPattern instances for Shape nil and cons --- -instance IsPattern Exp Z Z where - builder _ = constant Z - matcher _ = Z - -instance (Elt a, Elt b) => IsPattern Exp (a :. b) (Exp a :. Exp b) where - builder (Exp a :. Exp b) = Exp $ SmartExp $ Pair a b - matcher (Exp t) = Exp (SmartExp $ Prj PairIdxLeft t) :. Exp (SmartExp $ Prj PairIdxRight t) - - -- IsPattern instances for up to 16-tuples (Acc and Exp). TH takes care of -- the (unremarkable) boilerplate for us. -- runQ $ do - let - -- Generate instance declarations for IsPattern of the form: - -- instance (Arrays x, ArraysR x ~ (((), ArraysR a), ArraysR b), Arrays a, Arrays b,) => IsPattern Acc x (Acc a, Acc b) - mkAccPattern :: Int -> Q [Dec] - mkAccPattern n = do - a <- newName "a" - let - -- Type variables for the elements - xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - -- Last argument to `IsPattern`, eg (Acc a, Acc b) in the example - b = tupT (map (\t -> [t| Acc $(varT t)|]) xs) - -- Representation as snoc-list of pairs, eg (((), ArraysR a), ArraysR b) - snoc = foldl (\sn t -> [t| ($sn, ArraysR $(varT t)) |]) [t| () |] xs - -- Constraints for the type class, consisting of Arrays constraints on all type variables, - -- and an equality constraint on the representation type of `a` and the snoc representation `snoc`. - context = tupT - $ [t| Arrays $(varT a) |] - : [t| ArraysR $(varT a) ~ $snoc |] - : map (\t -> [t| Arrays $(varT t)|]) xs - -- - get x 0 = [| Acc (SmartAcc (Aprj PairIdxRight $x)) |] - get x i = get [| SmartAcc (Aprj PairIdxLeft $x) |] (i-1) - -- - _x <- newName "_x" - [d| instance $context => IsPattern Acc $(varT a) $b where - builder $(tupP (map (\x -> [p| Acc $(varP x)|]) xs)) = - Acc $(foldl (\vs v -> [| SmartAcc ($vs `Apair` $(varE v)) |]) [| SmartAcc Anil |] xs) - matcher (Acc $(varP _x)) = - $(tupE (map (get (varE _x)) [(n-1), (n-2) .. 0])) - |] - - -- Generate instance declarations for IsPattern of the form: - -- instance (Elt x, EltR x ~ (((), EltR a), EltR b), Elt a, Elt b,) => IsPattern Exp x (Exp a, Exp b) - mkExpPattern :: Int -> Q [Dec] - mkExpPattern n = do - a <- newName "a" - let - -- Type variables for the elements - xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - -- Variables for sub-pattern matches - ms = [ mkName ('m' : show i) | i <- [0 .. n-1] ] - tags = foldl (\ts t -> [p| $ts `TagRpair` $(varP t) |]) [p| TagRunit |] ms - -- Last argument to `IsPattern`, eg (Exp, a, Exp b) in the example - b = tupT (map (\t -> [t| Exp $(varT t)|]) xs) - -- Representation as snoc-list of pairs, eg (((), EltR a), EltR b) - snoc = foldl (\sn t -> [t| ($sn, EltR $(varT t)) |]) [t| () |] xs - -- Constraints for the type class, consisting of Elt constraints on all type variables, - -- and an equality constraint on the representation type of `a` and the snoc representation `snoc`. - context = tupT - $ [t| Elt $(varT a) |] - : [t| EltR $(varT a) ~ $snoc |] - : map (\t -> [t| Elt $(varT t)|]) xs - -- - get x 0 = [| SmartExp (Prj PairIdxRight $x) |] - get x i = get [| SmartExp (Prj PairIdxLeft $x) |] (i-1) - -- - _x <- newName "_x" - _y <- newName "_y" - [d| instance $context => IsPattern Exp $(varT a) $b where - builder $(tupP (map (\x -> [p| Exp $(varP x)|]) xs)) = - let _unmatch :: SmartExp a -> SmartExp a - _unmatch (SmartExp (Match _ $(varP _y))) = $(varE _y) - _unmatch x = x - in - Exp $(foldl (\vs v -> [| SmartExp ($vs `Pair` _unmatch $(varE v)) |]) [| SmartExp Nil |] xs) - matcher (Exp $(varP _x)) = - case $(varE _x) of - SmartExp (Match $tags $(varP _y)) - -> $(tupE [[| Exp (SmartExp (Match $(varE m) $(get (varE _x) i))) |] | m <- ms | i <- [(n-1), (n-2) .. 0]]) - _ -> $(tupE [[| Exp $(get (varE _x) i) |] | i <- [(n-1), (n-2) .. 0]]) - |] - - -- Generate instance declarations for IsVector of the form: - -- instance (Elt v, EltR v ~ Vec 2 a, Elt a) => IsVector Exp v (Exp a, Exp a) - mkVecPattern :: Int -> Q [Dec] - mkVecPattern n = do - a <- newName "a" - v <- newName "v" - let - -- Last argument to `IsVector`, eg (Exp, a, Exp a) in the example - tup = tupT (replicate n ([t| Exp $(varT a)|])) - -- Representation as a vector, eg (Vec 2 a) - vec = [t| Vec $(litT (numTyLit (fromIntegral n))) $(varT a) |] - -- Constraints for the type class, consisting of Elt constraints on all type variables, - -- and an equality constraint on the representation type of `a` and the vector representation `vec`. - context = [t| (Elt $(varT v), VecElt $(varT a), EltR $(varT v) ~ $vec) |] - -- - vecR = foldr appE ([| VecRnil |] `appE` (varE 'singleType `appTypeE` varT a)) (replicate n [| VecRsucc |]) - tR = tupT (replicate n (varT a)) - -- - [d| instance $context => IsVector Exp $(varT v) $tup where - vpack x = case builder x :: Exp $tR of - Exp x' -> Exp (SmartExp (VecPack $vecR x')) - vunpack (Exp x) = matcher (Exp (SmartExp (VecUnpack $vecR x)) :: Exp $tR) - |] - -- - es <- mapM mkExpPattern [0..16] - as <- mapM mkAccPattern [0..16] - vs <- mapM mkVecPattern [2,3,4,8,16] - return $ concat (es ++ as ++ vs) - - --- | Specialised pattern synonyms for tuples, which may be more convenient to --- use than 'Data.Array.Accelerate.Lift.lift' and --- 'Data.Array.Accelerate.Lift.unlift'. For example, to construct a pair: --- --- > let a = 4 :: Exp Int --- > let b = 2 :: Exp Float --- > let c = T2 a b -- :: Exp (Int, Float); equivalent to 'lift (a,b)' --- --- Similarly they can be used to destruct values: --- --- > let T2 x y = c -- x :: Exp Int, y :: Exp Float; equivalent to 'let (x,y) = unlift c' --- --- These pattern synonyms can be used for both 'Exp' and 'Acc' terms. --- --- Similarly, we have patterns for constructing and destructing indices of --- a given dimensionality: --- --- > let ix = Ix 2 3 -- :: Exp DIM2 --- > let I2 y x = ix -- y :: Exp Int, x :: Exp Int --- -runQ $ do - let - mkT :: Int -> Q [Dec] - mkT n = - let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - ts = map varT xs - name = mkName ('T':show n) - con = varT (mkName "con") - ty1 = tupT ts - ty2 = tupT (map (con `appT`) ts) - sig = foldr (\t r -> [t| $con $t -> $r |]) (appT con ty1) ts - in - sequence - [ patSynSigD name [t| IsPattern $con $ty1 $ty2 => $sig |] - , patSynD name (prefixPatSyn xs) implBidir [p| Pattern $(tupP (map varP xs)) |] - , pragCompleteD [name] (Just ''Acc) - , pragCompleteD [name] (Just ''Exp) - ] - - mkI :: Int -> Q [Dec] - mkI n = - let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - ts = map varT xs - name = mkName ('I':show n) - ix = mkName "Ix" - cst = tupT (map (\t -> [t| Elt $t |]) ts) - dim = foldl (\h t -> [t| $h :. $t |]) [t| Z |] ts - sig = foldr (\t r -> [t| Exp $t -> $r |]) [t| Exp $dim |] ts - in - sequence - [ patSynSigD name [t| $cst => $sig |] - , patSynD name (prefixPatSyn xs) implBidir (foldl (\ps p -> infixP ps ix (varP p)) [p| Z_ |] xs) - , pragCompleteD [name] Nothing - ] - - mkV :: Int -> Q [Dec] - mkV n = - let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] - ts = map varT xs - name = mkName ('V':show n) - con = varT (mkName "con") - ty1 = varT (mkName "vec") - ty2 = tupT (map (con `appT`) ts) - sig = foldr (\t r -> [t| $con $t -> $r |]) (appT con ty1) ts - in - sequence - [ patSynSigD name [t| IsVector $con $ty1 $ty2 => $sig |] - , patSynD name (prefixPatSyn xs) implBidir [p| Vector $(tupP (map varP xs)) |] - , pragCompleteD [name] (Just ''Exp) - ] - -- - ts <- mapM mkT [2..16] - is <- mapM mkI [0..9] - vs <- mapM mkV [2,3,4,8,16] - return $ concat (ts ++ is ++ vs) + let + -- Generate instance declarations for IsPattern of the form: + -- instance (Arrays x, ArraysR x ~ (((), ArraysR a), ArraysR b), Arrays a, Arrays b,) => IsPattern Acc x (Acc a, Acc b) + mkAccPattern :: Int -> Q [Dec] + mkAccPattern n = do + a <- newName "a" + _x <- newName "_x" + let + -- Type variables for the elements + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + -- Last argument to `IsPattern`, eg (Acc a, Acc b) in the example + b = tupT (map (\t -> [t| Acc $(varT t)|]) xs) + -- Representation as snoc-list of pairs, eg (((), ArraysR a), ArraysR b) + snoc = foldl (\sn t -> [t| ($sn, ArraysR $(varT t)) |]) [t| () |] xs + -- Constraints for the type class, consisting of Arrays constraints on all type variables, + -- and an equality constraint on the representation type of `a` and the snoc representation `snoc`. + context = tupT + $ [t| Arrays $(varT a) |] + : [t| ArraysR $(varT a) ~ $snoc |] + : map (\t -> [t| Arrays $(varT t)|]) xs + -- + get x 0 = [| Acc (SmartAcc (Aprj PairIdxRight $x)) |] + get x i = get [| SmartAcc (Aprj PairIdxLeft $x) |] (i-1) + -- + [d| instance $context => IsPattern Acc $(varT a) $b where + builder $(tupP (map (\x -> [p| Acc $(varP x)|]) xs)) = + Acc $(foldl (\vs v -> [| SmartAcc ($vs `Apair` $(varE v)) |]) [| SmartAcc Anil |] xs) + matcher (Acc $(varP _x)) = + $(tupE (map (get (varE _x)) [(n-1), (n-2) .. 0])) + |] + + -- Generate instance declarations for IsPattern of the form: + -- instance (Elt x, EltR x ~ (((), EltR a), EltR b), Elt a, Elt b,) => IsPattern Exp x (Exp a, Exp b) + mkExpPattern :: Int -> Q [Dec] + mkExpPattern n = do + a <- newName "a" + _x <- newName "_x" + _y <- newName "_y" + let + -- Type variables for the elements + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + -- Variables for sub-pattern matches + ms = [ mkName ('m' : show i) | i <- [0 .. n-1] ] + tags = foldl (\ts t -> [p| $ts `TagRpair` $(varP t) |]) [p| TagRunit |] ms + -- Last argument to `IsPattern`, eg (Exp, a, Exp b) in the example + b = tupT (map (\t -> [t| Exp $(varT t)|]) xs) + -- Representation as snoc-list of pairs, eg (((), EltR a), EltR b) + snoc = foldl (\sn t -> [t| ($sn, EltR $(varT t)) |]) [t| () |] xs + -- Constraints for the type class, consisting of Elt constraints on all type variables, + -- and an equality constraint on the representation type of `a` and the snoc representation `snoc`. + context = tupT + $ [t| Elt $(varT a) |] + : [t| EltR $(varT a) ~ $snoc |] + : map (\t -> [t| Elt $(varT t)|]) xs + -- + get x 0 = [| SmartExp (Prj PairIdxRight $x) |] + get x i = get [| SmartExp (Prj PairIdxLeft $x) |] (i-1) + -- + [d| instance $context => IsPattern Exp $(varT a) $b where + builder $(tupP (map (\x -> [p| Exp $(varP x)|]) xs)) = + let _unmatch :: SmartExp a -> SmartExp a + _unmatch (SmartExp (Match _ $(varP _y))) = $(varE _y) + _unmatch x = x + in + Exp $(foldl (\vs v -> [| SmartExp ($vs `Pair` _unmatch $(varE v)) |]) [| SmartExp Nil |] xs) + matcher (Exp $(varP _x)) = + case $(varE _x) of + SmartExp (Match $tags $(varP _y)) + -> $(tupE [[| Exp (SmartExp (Match $(varE m) $(get (varE _x) i))) |] | m <- ms | i <- [(n-1), (n-2) .. 0]]) + _ -> $(tupE [[| Exp $(get (varE _x) i) |] | i <- [(n-1), (n-2) .. 0]]) + |] + -- + es <- mapM mkExpPattern [0..16] + as <- mapM mkAccPattern [0..16] + return $ concat (es ++ as) diff --git a/src/Data/Array/Accelerate/Pattern/Bool.hs b/src/Data/Array/Accelerate/Pattern/Bool.hs index d968aaf34..d3dfbb0b0 100644 --- a/src/Data/Array/Accelerate/Pattern/Bool.hs +++ b/src/Data/Array/Accelerate/Pattern/Bool.hs @@ -1,9 +1,7 @@ -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ViewPatterns #-} -- | -- Module : Data.Array.Accelerate.Pattern.Bool -- Copyright : [2018..2020] The Accelerate Team @@ -16,11 +14,98 @@ module Data.Array.Accelerate.Pattern.Bool ( - Bool, pattern True_, pattern False_, + Bool, pattern True, pattern False, ) where -import Data.Array.Accelerate.Pattern.TH +import Data.Array.Accelerate.Representation.Tag +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Type -mkPattern ''Bool +import Data.Bool ( Bool ) +import Prelude hiding ( Bool(..) ) +import qualified Prelude as P + +import GHC.Stack + + +{-# COMPLETE False, True :: Exp #-} +{-# COMPLETE False, True :: Bool #-} +pattern False :: (HasCallStack, IsFalse r) => r +pattern False <- (matchFalse -> Just ()) + where False = buildFalse + +pattern True :: (HasCallStack, IsTrue r) => r +pattern True <- (matchTrue -> Just ()) + where True = buildTrue + +class IsFalse r where + buildFalse :: r + matchFalse :: r -> Maybe () + +instance IsFalse Bool where + buildFalse = P.False + matchFalse P.False = Just () + matchFalse _ = Nothing + +instance IsFalse (Exp Bool) where + buildFalse = _buildFalse + matchFalse = _matchFalse + +class IsTrue r where + buildTrue :: r + matchTrue :: r -> Maybe () + +instance IsTrue Bool where + buildTrue = P.True + matchTrue P.True = Just () + matchTrue _ = Nothing + +instance IsTrue (Exp Bool) where + buildTrue = _buildTrue + matchTrue = _matchTrue + + + +_buildFalse :: Exp Bool +_buildFalse = mkExp $ Const scalarType 0 + +_matchFalse :: HasCallStack => Exp Bool -> Maybe () +_matchFalse (Exp e) = + case e of + SmartExp (Match (TagRenum TagBit 0) _) -> Just () + SmartExp Match{} -> Nothing + _ -> error $ unlines + [ "Embedded pattern synonym used outside 'match' context." + , "" + , "To use case statements in the embedded language the case statement must" + , "be applied as an n-ary function to the 'match' operator. For single" + , "argument case statements this can be done inline using LambdaCase, for" + , "example:" + , "" + , "> x & match \\case" + , "> False_ -> ..." + , "> _ -> ..." + ] + +_buildTrue :: Exp Bool +_buildTrue = mkExp $ Const scalarType 1 + +_matchTrue :: HasCallStack => Exp Bool -> Maybe () +_matchTrue (Exp e) = + case e of + SmartExp (Match (TagRenum TagBit 1) _) -> Just () + SmartExp Match{} -> Nothing + _ -> error $ unlines + [ "Embedded pattern synonym used outside 'match' context." + , "" + , "To use case statements in the embedded language the case statement must" + , "be applied as an n-ary function to the 'match' operator. For single" + , "argument case statements this can be done inline using LambdaCase, for" + , "example:" + , "" + , "> x & match \\case" + , "> True_ -> ..." + , "> _ -> ..." + ] diff --git a/src/Data/Array/Accelerate/Pattern/Either.hs b/src/Data/Array/Accelerate/Pattern/Either.hs index 67c7b3a3f..e94ab74a9 100644 --- a/src/Data/Array/Accelerate/Pattern/Either.hs +++ b/src/Data/Array/Accelerate/Pattern/Either.hs @@ -1,9 +1,11 @@ -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} -- | -- Module : Data.Array.Accelerate.Pattern.Either -- Copyright : [2018..2020] The Accelerate Team @@ -16,11 +18,68 @@ module Data.Array.Accelerate.Pattern.Either ( - Either, pattern Left_, pattern Right_, + Either, pattern Left, pattern Right, ) where import Data.Array.Accelerate.Pattern.TH +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Smart -mkPattern ''Either +import Data.Either ( Either ) +import Language.Haskell.TH.Extra hiding ( Exp ) +import Prelude hiding ( Either(..) ) +import qualified Data.List as P +import qualified Prelude as P + + +runQ $ do + let it SigD{} = True + it FunD{} = True + it _ = False + + find _ [] = error "could not find specified function" + find pat (d:ds) = + case d of + SigD n _ | pat `P.isPrefixOf` nameBase n -> varE n + _ -> find pat ds + + decs <- filter it <$> mkPattern ''Either + rest <- [d| {-# COMPLETE Left, Right :: Exp #-} + {-# COMPLETE Left, Right :: Either #-} + pattern Left :: IsLeft a b r => a -> r + pattern Left x <- (matchLeft -> P.Just x) + where Left = buildLeft + + class IsLeft a b r | r -> a b where + buildLeft :: a -> r + matchLeft :: r -> Maybe a + + instance IsLeft a b (Either a b) where + buildLeft = P.Left + matchLeft (P.Left a) = P.Just a + matchLeft _ = P.Nothing + + instance (Elt a, Elt b) => IsLeft (Exp a) (Exp b) (Exp (Either a b)) where + buildLeft = $(find "_buildLeft" decs) + matchLeft = $(find "_matchLeft" decs) + + pattern Right :: IsRight a b r => b -> r + pattern Right x <- (matchRight -> P.Just x) + where Right = buildRight + + class IsRight a b r | r -> a b where + buildRight :: b -> r + matchRight :: r -> Maybe b + + instance IsRight a b (Either a b) where + buildRight = P.Right + matchRight (P.Right b) = P.Just b + matchRight _ = P.Nothing + + instance (Elt a, Elt b) => IsRight (Exp a) (Exp b) (Exp (Either a b)) where + buildRight = $(find "_buildRight" decs) + matchRight = $(find "_matchRight" decs) + |] + return (decs ++ rest) diff --git a/src/Data/Array/Accelerate/Pattern/Maybe.hs b/src/Data/Array/Accelerate/Pattern/Maybe.hs index 67e341d64..f7a7395bb 100644 --- a/src/Data/Array/Accelerate/Pattern/Maybe.hs +++ b/src/Data/Array/Accelerate/Pattern/Maybe.hs @@ -1,9 +1,12 @@ -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} -- | -- Module : Data.Array.Accelerate.Pattern.Maybe -- Copyright : [2018..2020] The Accelerate Team @@ -16,11 +19,71 @@ module Data.Array.Accelerate.Pattern.Maybe ( - Maybe, pattern Nothing_, pattern Just_, + Maybe, pattern Nothing, pattern Just, ) where import Data.Array.Accelerate.Pattern.TH +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Smart -mkPattern ''Maybe +import Data.Maybe ( Maybe ) +import Language.Haskell.TH.Extra hiding ( Exp ) +import Prelude hiding ( Maybe(..) ) +import qualified Data.List as P +import qualified Prelude as P + +import GHC.Stack + + +-- TODO: We should make this a feature of the mkPattern machinery +-- +runQ $ do + let it SigD{} = True + it FunD{} = True + it _ = False + + find _ [] = error "could not find specified function" + find pat (d:ds) = + case d of + SigD n _ | pat `P.isPrefixOf` nameBase n -> varE n + _ -> find pat ds + + decs <- filter it <$> mkPattern ''Maybe + rest <- [d| {-# COMPLETE Nothing, Just :: Exp #-} + {-# COMPLETE Nothing, Just :: Maybe #-} + pattern Nothing :: (HasCallStack, IsNothing r) => r + pattern Nothing <- (matchNothing -> P.Just ()) + where Nothing = buildNothing + + pattern Just :: (HasCallStack, IsJust a r) => a -> r + pattern Just x <- (matchJust -> P.Just x) + where Just = buildJust + + class IsNothing r where + buildNothing :: r + matchNothing :: r -> Maybe () + + instance IsNothing (Maybe a) where + buildNothing = P.Nothing + matchNothing P.Nothing = P.Just () + matchNothing _ = P.Nothing + + instance Elt a => IsNothing (Exp (Maybe a)) where + buildNothing = $(find "_buildNothing" decs) + matchNothing = $(find "_matchNothing" decs) + + class IsJust a r | r -> a where + buildJust :: a -> r + matchJust :: r -> Maybe a + + instance IsJust a (Maybe a) where + buildJust = P.Just + matchJust = P.id + + instance Elt a => IsJust (Exp a) (Exp (Maybe a)) where + buildJust = $(find "_buildJust" decs) + matchJust = $(find "_matchJust" decs) + |] + return (decs ++ rest) diff --git a/src/Data/Array/Accelerate/Pattern/Ordering.hs b/src/Data/Array/Accelerate/Pattern/Ordering.hs index 2407cf9e9..41441733a 100644 --- a/src/Data/Array/Accelerate/Pattern/Ordering.hs +++ b/src/Data/Array/Accelerate/Pattern/Ordering.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -16,11 +17,85 @@ module Data.Array.Accelerate.Pattern.Ordering ( - Ordering, pattern LT_, pattern EQ_, pattern GT_, + Ordering, pattern LT, pattern EQ, pattern GT, ) where import Data.Array.Accelerate.Pattern.TH +import Data.Array.Accelerate.Smart -mkPattern ''Ordering +import Data.Ord ( Ordering ) +import Language.Haskell.TH.Extra hiding ( Exp ) +import Prelude hiding ( Ordering(..) ) +import qualified Data.List as P +import qualified Prelude as P + + +runQ $ do + let it SigD{} = True + it FunD{} = True + it _ = False + + find _ [] = error "could not find specified function" + find pat (d:ds) = + case d of + SigD n _ | pat `P.isPrefixOf` nameBase n -> varE n + _ -> find pat ds + + decs <- filter it <$> mkPattern ''Ordering + rest <- [d| {-# COMPLETE LT, EQ, GT :: Exp #-} + {-# COMPLETE LT, EQ, GT :: Ordering #-} + pattern LT :: IsLT a => a + pattern LT <- (matchLT -> Just ()) + where LT = buildLT + + class IsLT a where + buildLT :: a + matchLT :: a -> Maybe () + + instance IsLT Ordering where + buildLT = P.LT + matchLT P.LT = Just () + matchLT _ = Nothing + + instance IsLT (Exp Ordering) where + buildLT = $(find "_buildLT" decs) + matchLT = $(find "_matchLT" decs) + + pattern EQ :: IsEQ a => a + pattern EQ <- (matchEQ -> Just ()) + where EQ = buildEQ + + class IsEQ a where + buildEQ :: a + matchEQ :: a -> Maybe () + + instance IsEQ Ordering where + buildEQ = P.EQ + matchEQ P.EQ = Just () + matchEQ _ = Nothing + + instance IsEQ (Exp Ordering) where + buildEQ = $(find "_buildEQ" decs) + matchEQ = $(find "_matchEQ" decs) + + pattern GT :: IsGT a => a + pattern GT <- (matchGT -> Just ()) + where GT = buildGT + + class IsGT a where + buildGT :: a + matchGT :: a -> Maybe () + + instance IsGT Ordering where + buildGT = P.GT + matchGT P.GT = Just () + matchGT _ = Nothing + + instance IsGT (Exp Ordering) where + buildGT = $(find "_buildGT" decs) + matchGT = $(find "_matchGT" decs) + + |] + return (decs ++ rest) diff --git a/src/Data/Array/Accelerate/Pattern/SIMD.hs b/src/Data/Array/Accelerate/Pattern/SIMD.hs new file mode 100644 index 000000000..429f28864 --- /dev/null +++ b/src/Data/Array/Accelerate/Pattern/SIMD.hs @@ -0,0 +1,168 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +-- | +-- Module : Data.Array.Accelerate.Pattern.SIMD +-- Copyright : [2018..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Pattern.SIMD ( + + pattern SIMD, + pattern V2, pattern V3, pattern V4, pattern V8, pattern V16, + +) where + +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Vec +import Data.Array.Accelerate.Type + +import Language.Haskell.TH.Extra hiding ( Exp, Match ) +import GHC.Exts ( IsList(..) ) + + +pattern SIMD :: forall b a context. IsSIMD context a b => b -> context a +pattern SIMD vars <- (vmatcher @context -> vars) + where SIMD = vbuilder @context +{-# COMPLETE SIMD :: Exp #-} + +class IsSIMD context a b where + vbuilder :: b -> context a + vmatcher :: context a -> b + + +runQ $ + let + -- Generate instance declarations for IsSIMD of the form: + -- instance (Elt a, Elt v, EltR v ~ VecR n a) => IsSIMD Exp v (Exp a, Exp a) + mkV :: Int -> Q [Dec] + mkV n = do + a <- newName "a" + v <- newName "v" + _x <- newName "_x" + _y <- newName "_y" + let + aT = varT a + vT = varT v + nT = litT (numTyLit (toInteger n)) + -- Last argument to `IsSIMD`, eg (Exp, a, Exp a) in the example + tup = tupT (replicate n ([t| Exp $aT |])) + -- Constraints for the type class, consisting of the Elt + -- constraints and the equality on representation types + context = [t| (Elt $aT, Elt $vT, SIMD $nT $aT, EltR $vT ~ VecR $nT $aT) |] + -- Type variables for the elements + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + -- Variables for sub-pattern matches + -- ms = [ mkName ('m' : show i) | i <- [0 .. n-1] ] + -- tags = foldl (\ts t -> [p| $ts `TagRpair` $(varP t) |]) [p| TagRunit |] ms + -- + [d| instance $context => IsSIMD Exp $vT $tup where + vbuilder $(tupP (map (\x -> [p| Exp $(varP x)|]) xs)) = + let _unmatch :: SmartExp a -> SmartExp a + _unmatch (SmartExp (Match _ $(varP _y))) = $(varE _y) + _unmatch x = x + in + Exp $(foldl (\vs (i, x) -> [| mkInsert + $(varE 'vecR `appTypeE` nT `appTypeE` aT) + $(varE 'eltR `appTypeE` aT) + TypeWord8 + $vs + (SmartExp (Const (NumScalarType (IntegralNumType (SingleIntegralType TypeWord8))) i)) + (_unmatch $(varE x)) + |]) + [| unExp (undef :: Exp (Vec $nT $aT)) |] + (zip [0 .. n-1] xs) + ) + + vmatcher (Exp $(varP _x)) = + case $(varE _x) of + -- SmartExp (Match $tags $(varP _y)) + -- -> $(tupE [[| Exp (SmartExp (Match $(varE m) (unExp (extract (Exp $(varE _x) :: Exp $vec) (constant (i :: Word8)))))) |] | m <- ms | i <- [0 .. n-1]]) + -- -> $(tupE [[| Exp (SmartExp (Match $(varE m) (mkExtract + -- $(varE 'vecR `appTypeE` nT `appTypeE` aT) + -- $(varE 'eltR `appTypeE` aT) + -- TypeWord8 + -- $(varE _x) + -- (SmartExp (Const (NumScalarType (IntegralNumType (SingleIntegralType TypeWord8))) i))))) |] + -- | m <- ms + -- | i <- [0 .. n-1] ]) + + _ -> $(tupE [[| Exp $ mkExtract + $(varE 'vecR `appTypeE` nT `appTypeE` aT) + $(varE 'eltR `appTypeE` aT) + TypeWord8 + $(varE _x) + (SmartExp (Const (NumScalarType (IntegralNumType (SingleIntegralType TypeWord8))) i)) + |] + | i <- [0 .. n-1] ]) + |] + in + concat <$> mapM mkV [2,3,4,8,16] + +-- Generate polymorphic pattern synonyms which operate on both Haskell values +-- as well as embedded expressions +-- +runQ $ + let + mkV :: Int -> Q [Dec] + mkV n = do + a <- newName "a" + v <- newName "v" + let + as = replicate n (varT a) + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + xsP = map varP xs + xsE = map varE xs + name = mkName ("V" ++ show n) + isV = mkName ("IsV" ++ show n) + builder = mkName ("buildV" ++ show n) + matcher = mkName ("matchV" ++ show n) + ctx = return [ ConT ''Elt `AppT` VarT a + , ConT ''SIMD `AppT` LitT (NumTyLit (toInteger n)) `AppT` VarT a + ] + ctx' = return [ ConT ''Elt `AppT` VarT a + , ConT ''SIMD `AppT` LitT (NumTyLit (toInteger n)) `AppT` VarT a + , EqualityT `AppT` VarT v `AppT` (ConT name `AppT` VarT a) + ] + -- + sequence + [ patSynSigD name [t| $(conT isV) $(varT a) $(varT v) => $(foldr (\t r -> [t| $t -> $r |]) (varT v) as) |] + , patSynD name (prefixPatSyn xs) (explBidir [clause [] (normalB (varE builder)) []]) (parensP $ viewP (varE matcher) (tupP xsP)) + , pragCompleteD [name] (Just ''Vec) + , pragCompleteD [name] (Just ''Exp) + -- + , classD (return []) isV [plainTV a, plainTV v] [funDep [v] [a]] + [ sigD builder (foldr (\t r -> [t| $t -> $r |]) (varT v) as) + , sigD matcher [t| $(varT v) -> $(tupT as) |] + ] + -- This instance which goes via toList is horrible and I feel bad for using it + -- TLM 2022-06-27 + , instanceD ctx [t| $(conT isV) $(varT a) ($(conT name) $(varT a)) |] + [ funD builder [ clause xsP (normalB [| fromList $(listE xsE) |]) []] + , funD matcher [ clause [viewP (varE 'toList) (listP xsP)] (normalB (tupE xsE)) [] ] + ] + , instanceD ctx' [t| $(conT isV) (Exp $(varT a)) (Exp $(varT v)) |] + [ funD builder [ clause xsP (normalB [| SIMD $(tupE xsE) |]) []] + , funD matcher [ clause [conP (mkName "SIMD") [tupP xsP]] (normalB (tupE xsE)) [] ] + ] + ] + in + concat <$> mapM mkV [2,3,4,8,16] + diff --git a/src/Data/Array/Accelerate/Pattern/Shape.hs b/src/Data/Array/Accelerate/Pattern/Shape.hs new file mode 100644 index 000000000..8f4c1beee --- /dev/null +++ b/src/Data/Array/Accelerate/Pattern/Shape.hs @@ -0,0 +1,141 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +-- | +-- Module : Data.Array.Accelerate.Pattern.Shape +-- Copyright : [2018..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Pattern.Shape ( + + pattern Z, Z, + pattern (:.), (:.), + pattern All, All, + pattern Any, Any, + + pattern I0, pattern I1, pattern I2, pattern I3, pattern I4, + pattern I5, pattern I6, pattern I7, pattern I8, pattern I9, + +) where + +import Data.Array.Accelerate.AST.Idx +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Shape ( (:.), Z, All, Any ) +import qualified Data.Array.Accelerate.Sugar.Shape as Sugar + +import Language.Haskell.TH.Extra hiding ( Exp, Match ) + + +pattern Z :: IsShapeZ z => z +pattern Z <- (matchZ -> True) + where Z = buildZ +{-# COMPLETE Z :: Z #-} +{-# COMPLETE Z :: Exp #-} + +pattern All :: IsShapeAll all => all +pattern All <- (matchAll -> True) + where All = buildAll +{-# COMPLETE All :: All #-} +{-# COMPLETE All :: Exp #-} + +pattern Any :: IsShapeAny any => any +pattern Any <- (matchAny -> True) + where Any = buildAny +{-# COMPLETE Any :: Any #-} +{-# COMPLETE Any :: Exp #-} + +infixl 3 :. +pattern (:.) :: IsShapeSnoc t h s => t -> h -> s +pattern t :. h <- (matchSnoc -> (t Sugar.:. h)) + where t :. h = buildSnoc t h +{-# COMPLETE (:.) :: (:.) #-} +{-# COMPLETE (:.) :: Exp #-} + + +class IsShapeZ z where + matchZ :: z -> Bool + buildZ :: z + +instance IsShapeZ Z where + matchZ _ = True + buildZ = Sugar.Z + +instance IsShapeZ (Exp Z) where + matchZ _ = True + buildZ = constant Sugar.Z + +class IsShapeAll all where + matchAll :: all -> Bool + buildAll :: all + +instance IsShapeAll All where + matchAll _ = True + buildAll = Sugar.All + +instance IsShapeAll (Exp All) where + matchAll _ = True + buildAll = constant Sugar.All + +class IsShapeAny any where + matchAny :: any -> Bool + buildAny :: any + +instance IsShapeAny (Any sh) where + matchAny _ = True + buildAny = Sugar.Any + +instance Elt (Any sh) => IsShapeAny (Exp (Any sh)) where + matchAny _ = True + buildAny = constant Sugar.Any + +class IsShapeSnoc t h s | s -> h t where + matchSnoc :: s -> (t :. h) + buildSnoc :: t -> h -> s + +instance IsShapeSnoc (Exp t) (Exp h) (Exp (t :. h)) where + buildSnoc (Exp a) (Exp b) = Exp $ SmartExp $ Pair a b + matchSnoc (Exp t) = Exp (SmartExp $ Prj PairIdxLeft t) Sugar.:. Exp (SmartExp $ Prj PairIdxRight t) + +instance IsShapeSnoc t h (t :. h) where + buildSnoc = (Sugar.:.) + matchSnoc = id + + +-- Generate patterns for constructing and destructing indices of a given +-- dimensionality: +-- +-- > let ix = Ix 2 3 -- :: Exp DIM2 +-- > let I2 y x = ix -- y :: Exp Int, x :: Exp Int +-- +runQ $ + let + mkI :: Int -> Q [Dec] + mkI n = + let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + ts = map varT xs + name = mkName ('I':show n) + ix = mkName ":." + cst = tupT (map (\t -> [t| Elt $t |]) ts) + dim = foldl (\h t -> [t| $h :. $t |]) [t| Z |] ts + sig = foldr (\t r -> [t| Exp $t -> $r |]) [t| Exp $dim |] ts + in + sequence + [ patSynSigD name [t| $cst => $sig |] + , patSynD name (prefixPatSyn xs) implBidir (foldl (\ps p -> infixP ps ix (varP p)) [p| Z |] xs) + , pragCompleteD [name] Nothing + ] + in + concat <$> mapM mkI [0..9] + diff --git a/src/Data/Array/Accelerate/Pattern/TH.hs b/src/Data/Array/Accelerate/Pattern/TH.hs index 0323f8d1a..6a3f8303e 100644 --- a/src/Data/Array/Accelerate/Pattern/TH.hs +++ b/src/Data/Array/Accelerate/Pattern/TH.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} -- | @@ -12,8 +13,9 @@ module Data.Array.Accelerate.Pattern.TH ( - mkPattern, - mkPatterns, + Options(..), defaultOptions, + mkPattern, mkPatternWith, + mkPatterns, mkPatternsWith, ) where @@ -36,10 +38,24 @@ import qualified Language.Haskell.TH.Extra as TH import GHC.Stack +-- | Options to control what is generated +-- +data Options = Options + { renameNormalC :: Name -> Name + , renameInfixC :: Name -> Name + } + +defaultOptions :: Options +defaultOptions = Options defaultRenameNormalC defaultRenameInfixC + + -- | As 'mkPattern', but for a list of types -- mkPatterns :: [Name] -> DecsQ -mkPatterns nms = concat <$> mapM mkPattern nms +mkPatterns = mkPatternsWith defaultOptions + +mkPatternsWith :: Options -> [Name] -> DecsQ +mkPatternsWith opts nms = concat <$> mapM (mkPatternWith opts) nms -- | Generate pattern synonyms for the given simple (Haskell'98) sum or -- product data type. @@ -61,24 +77,27 @@ mkPatterns nms = concat <$> mapM mkPattern nms -- > ycoord :: Exp Point -> Exp Float -- mkPattern :: Name -> DecsQ -mkPattern nm = do +mkPattern = mkPatternWith defaultOptions + +mkPatternWith :: Options -> Name -> DecsQ +mkPatternWith opt nm = do info <- reify nm case info of - TyConI dec -> mkDec dec + TyConI dec -> mkDec opt dec _ -> fail "mkPatterns: expected the name of a newtype or datatype" -mkDec :: Dec -> DecsQ -mkDec dec = +mkDec :: Options -> Dec -> DecsQ +mkDec opt dec = case dec of - DataD _ nm tv _ cs _ -> mkDataD nm tv cs - NewtypeD _ nm tv _ c _ -> mkNewtypeD nm tv c + DataD _ nm tv _ cs _ -> mkDataD opt nm tv cs + NewtypeD _ nm tv _ c _ -> mkNewtypeD opt nm tv c _ -> fail "mkPatterns: expected the name of a newtype or datatype" -mkNewtypeD :: Name -> [TyVarBndr ()] -> Con -> DecsQ -mkNewtypeD tn tvs c = mkDataD tn tvs [c] +mkNewtypeD :: Options -> Name -> [TyVarBndr ()] -> Con -> DecsQ +mkNewtypeD opt tn tvs c = mkDataD opt tn tvs [c] -mkDataD :: Name -> [TyVarBndr ()] -> [Con] -> DecsQ -mkDataD tn tvs cs = do +mkDataD :: Options -> Name -> [TyVarBndr ()] -> [Con] -> DecsQ +mkDataD opt tn tvs cs = do (pats, decs) <- unzip <$> go cs comp <- pragCompleteD pats Nothing return $ comp : concat decs @@ -86,14 +105,14 @@ mkDataD tn tvs cs = do -- For single-constructor types we create the pattern synonym for the -- type directly in terms of Pattern go [] = fail "mkPatterns: empty data declarations not supported" - go [c] = return <$> mkConP tn tvs c + go [c] = return <$> mkConP opt tn tvs c go _ = go' [] (map fieldTys cs) ctags cs -- For sum-types, when creating the pattern for an individual -- constructor we need to know about the types of the fields all other -- constructors as well go' prev (this:next) (tag:tags) (con:cons) = do - r <- mkConS tn tvs prev next tag con + r <- mkConS opt tn tvs prev next tag con rs <- go' (this:prev) next tags cons return (r : rs) go' _ [] [] [] = return [] @@ -122,12 +141,12 @@ mkDataD tn tvs cs = do map bitsToTag (l ++ r) -mkConP :: Name -> [TyVarBndr ()] -> Con -> Q (Name, [Dec]) -mkConP tn' tvs' con' = do +mkConP :: Options -> Name -> [TyVarBndr ()] -> Con -> Q (Name, [Dec]) +mkConP Options{..} tn' tvs' con' = do checkExts [ PatternSynonyms ] case con' of NormalC cn fs -> mkNormalC tn' cn (map tyVarBndrName tvs') (map snd fs) - RecC cn fs -> mkRecC tn' cn (map tyVarBndrName tvs') (map (rename . fst3) fs) (map thd3 fs) + RecC cn fs -> mkRecC tn' cn (map tyVarBndrName tvs') (map (renameNormalC . fst3) fs) (map thd3 fs) InfixC a cn b -> mkInfixC tn' cn (map tyVarBndrName tvs') [snd a, snd b] _ -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported" where @@ -142,7 +161,7 @@ mkConP tn' tvs' con' = do ] return (pat, r) where - pat = rename cn + pat = renameNormalC cn sig = forallT (map (`plainInvisTV` specifiedSpec) tvs) (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) @@ -160,7 +179,7 @@ mkConP tn' tvs' con' = do ] return (pat, r) where - pat = rename cn + pat = renameNormalC cn sig = forallT (map (`plainInvisTV` specifiedSpec) tvs) (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) @@ -184,7 +203,7 @@ mkConP tn' tvs' con' = do Just f -> return (InfixD f pat : r) return (pat, r') where - pat = mkName (':' : nameBase cn) + pat = renameInfixC cn sig = forallT (map (`plainInvisTV` specifiedSpec) tvs) (cxt (map (\t -> [t| Elt $(varT t) |]) tvs)) @@ -192,18 +211,18 @@ mkConP tn' tvs' con' = do [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |] (map (\t -> [t| Exp $(return t) |]) fs)) -mkConS :: Name -> [TyVarBndr ()] -> [[Type]] -> [[Type]] -> Word8 -> Con -> Q (Name, [Dec]) -mkConS tn' tvs' prev' next' tag' con' = do +mkConS :: Options -> Name -> [TyVarBndr ()] -> [[Type]] -> [[Type]] -> Word8 -> Con -> Q (Name, [Dec]) +mkConS Options{..} tn' tvs' prev' next' tag' con' = do checkExts [GADTs, PatternSynonyms, ScopedTypeVariables, TypeApplications, ViewPatterns] case con' of NormalC cn fs -> mkNormalC tn' cn tag' (map tyVarBndrName tvs') prev' (map snd fs) next' - RecC cn fs -> mkRecC tn' cn tag' (map tyVarBndrName tvs') (map (rename . fst3) fs) prev' (map thd3 fs) next' + RecC cn fs -> mkRecC tn' cn tag' (map tyVarBndrName tvs') (map (renameNormalC . fst3) fs) prev' (map thd3 fs) next' InfixC a cn b -> mkInfixC tn' cn tag' (map tyVarBndrName tvs') prev' [snd a, snd b] next' _ -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported" where mkNormalC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) mkNormalC tn cn tag tvs ps fs ns = do - let pat = rename cn + let pat = renameNormalC cn (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns dec_pat <- mkNormalC_pattern tn pat tvs fs fun_build fun_match @@ -211,7 +230,7 @@ mkConS tn' tvs' prev' next' tag' con' = do mkRecC :: Name -> Name -> Word8 -> [Name] -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) mkRecC tn cn tag tvs xs ps fs ns = do - let pat = rename cn + let pat = renameNormalC cn (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns dec_pat <- mkRecC_pattern tn pat tvs xs fs fun_build fun_match @@ -219,9 +238,10 @@ mkConS tn' tvs' prev' next' tag' con' = do mkInfixC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec]) mkInfixC tn cn tag tvs ps fs ns = do - let pat = mkName (':' : nameBase cn) - (fun_build, dec_build) <- mkBuild tn (zencode (nameBase cn)) tvs tag ps fs ns - (fun_match, dec_match) <- mkMatch tn ("(" ++ nameBase pat ++ ")") (zencode (nameBase cn)) tvs tag ps fs ns + let pat = renameInfixC cn + zcn = zencode (nameBase cn) + (fun_build, dec_build) <- mkBuild tn zcn tvs tag ps fs ns + (fun_match, dec_match) <- mkMatch tn ("(" ++ nameBase pat ++ ")") zcn tvs tag ps fs ns dec_pat <- mkInfixC_pattern tn cn pat tvs fs fun_build fun_match return $ (pat, concat [dec_pat, dec_build, dec_match]) @@ -293,7 +313,7 @@ mkConS tn' tvs' prev' next' tag' con' = do ++ map varE xs ++ map (\t -> [| unExp $(varE 'undef `appTypeE` return t) |] ) (concat fs1) - tagged = [| Exp $ SmartExp $ Pair (SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType TypeWord8))) $(litE (IntegerL (toInteger tag))))) $vs |] + tagged = [| Exp (SmartExp (Pair (SmartExp (Const (NumScalarType (IntegralNumType (SingleIntegralType TypeWord8))) $(litE (IntegerL (toInteger tag))))) $vs)) |] body = clause (map (\x -> [p| (Exp $(varP x)) |]) xs) (normalB tagged) [] r <- sequence [ sigD fun sig @@ -314,7 +334,7 @@ mkConS tn' tvs' prev' next' tag' con' = do fun <- newName ("_match" ++ cn) e <- newName "_e" x <- newName "_x" - (ps,es) <- extract vs [| Prj PairIdxRight $(varE x) |] [] [] + (ps,es) <- prj vs [| Prj PairIdxRight $(varE x) |] [] [] unbind <- isExtEnabled RebindableSyntax let eqE = if unbind then letE [funD (mkName "==") [clause [] (normalB (varE '(==))) []]] else id @@ -335,17 +355,17 @@ mkConS tn' tvs' prev' next' tag' con' = do (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs)) [t| Exp $(foldl' appT (conT tn) (map varT tvs)) -> Maybe $(tupT (map (\t -> [t| Exp $(return t) |]) fs)) |] - matchP us = [p| TagRtag $(litP (IntegerL (toInteger tag))) $pat |] + matchP us = [p| TagRcon _ $(litP (IntegerL (toInteger tag))) $pat |] where pat = [p| $(foldl (\ps p -> [p| TagRpair $ps $p |]) [p| TagRunit |] us) |] - extract [] _ ps es = return (ps, es) - extract (u:us) x ps es = do + prj [] _ ps es = return (ps, es) + prj (u:us) x ps es = do _u <- newName "_u" let x' = [| Prj PairIdxLeft (SmartExp $x) |] if not u - then extract us x' (wildP:ps) es - else extract us x' (varP _u:ps) ([| Exp (SmartExp (Match $(varE _u) (SmartExp (Prj PairIdxRight (SmartExp $x))))) |] : es) + then prj us x' (wildP:ps) es + else prj us x' (varP _u:ps) ([| Exp (SmartExp (Match $(varE _u) (SmartExp (Prj PairIdxRight (SmartExp $x))))) |] : es) vs = reverse $ [ False | _ <- concat fs0 ] ++ [ True | _ <- fs ] ++ [ False | _ <- concat fs1 ] @@ -374,8 +394,8 @@ fst3 (a,_,_) = a thd3 :: (a,b,c) -> c thd3 (_,_,c) = c -rename :: Name -> Name -rename nm = +defaultRenameNormalC :: Name -> Name +defaultRenameNormalC nm = let split acc [] = (reverse acc, '\0') -- shouldn't happen split acc [l] = (reverse acc, l) @@ -388,6 +408,9 @@ rename nm = '_' -> mkName base _ -> mkName (nm' ++ "_") +defaultRenameInfixC :: Name -> Name +defaultRenameInfixC nm = mkName (':' : nameBase nm) + checkExts :: [Extension] -> Q () checkExts req = do enabled <- extsEnabled diff --git a/src/Data/Array/Accelerate/Pattern/Tuple.hs b/src/Data/Array/Accelerate/Pattern/Tuple.hs new file mode 100644 index 000000000..4f34fe26c --- /dev/null +++ b/src/Data/Array/Accelerate/Pattern/Tuple.hs @@ -0,0 +1,88 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +-- | +-- Module : Data.Array.Accelerate.Pattern.Tuple +-- Copyright : [2018..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Array.Accelerate.Pattern.Tuple ( + + pattern T2, pattern T3, pattern T4, pattern T5, pattern T6, + pattern T7, pattern T8, pattern T9, pattern T10, pattern T11, + pattern T12, pattern T13, pattern T14, pattern T15, pattern T16, + +) where + +import Data.Array.Accelerate.Pattern ( pattern Pattern ) +import Data.Array.Accelerate.Smart +import Data.Array.Accelerate.Sugar.Array +import Data.Array.Accelerate.Sugar.Elt + +import Language.Haskell.TH.Extra hiding ( Exp, Match ) + + +-- Generate polymorphic pattern synonyms to construct and destruct tuples on +-- both Haskell values and embedded expressions. This isn't really necessary but +-- provides for a more consistent interface. +-- +runQ $ + let + mkT :: Int -> Q [Dec] + mkT n = + let xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + res = mkName "r" + xsT = map varT xs + xsP = map varP xs + xsE = map varE xs + name = mkName ('T':show n) + isT = mkName ("IsT" ++ show n) + builder = mkName ("buildT" ++ show n) + matcher = mkName ("matchT" ++ show n) + sig = foldr (\t r -> [t| $t -> $r |]) (varT res) xsT + hdr ts r = foldl appT (conT isT) (ts ++ [r]) + in + sequence + -- Value/Embedded polymorphic pattern synonym + [ patSynSigD name [t| $(hdr xsT (varT res)) => $sig |] + , patSynD name (prefixPatSyn xs) (explBidir [clause [] (normalB (varE builder)) []]) (parensP $ viewP (varE matcher) (tupP xsP)) + , pragCompleteD [name] (Just (tupleTypeName n)) + , pragCompleteD [name] (Just ''Acc) + , pragCompleteD [name] (Just ''Exp) + -- + , classD (return []) isT (map plainTV (xs ++ [res])) [funDep [res] xs] + [ sigD builder sig + , sigD matcher [t| $(varT res) -> $(tupT xsT) |] + ] + , instanceD (return []) [t| $(hdr xsT (tupT xsT)) |] + [ funD builder [ clause xsP (normalB (tupE xsE)) [] ] + , funD matcher [ clause [] (normalB [| id |]) [] ] + ] + , instanceD (sequence ( [t| $(varT res) ~ $(tupT xsT) |] : map (\x -> [t| Elt $x |]) xsT )) [t| $(hdr (map (\x -> [t| Exp $x |]) xsT) [t| Exp $(varT res) |]) |] + [ funD builder [ clause xsP (normalB [| Pattern $(tupE xsE) |]) [] ] + , funD matcher [ clause [conP (mkName "Pattern") [tupP xsP]] (normalB (tupE xsE)) []] + ] + , instanceD (sequence ( [t| $(varT res) ~ $(tupT xsT) |] : map (\x -> [t| Arrays $x |]) xsT)) [t| $(hdr (map (\x -> [t| Acc $x |]) xsT) [t| Acc $(varT res) |]) |] + [ funD builder [ clause xsP (normalB [| Pattern $(tupE xsE) |]) [] ] + , funD matcher [ clause [conP (mkName "Pattern") [tupP xsP]] (normalB (tupE xsE)) []] + ] + ] + in + concat <$> mapM mkT [2..16] + diff --git a/src/Data/Array/Accelerate/Prelude.hs b/src/Data/Array/Accelerate/Prelude.hs index a833ad12c..6cf549119 100644 --- a/src/Data/Array/Accelerate/Prelude.hs +++ b/src/Data/Array/Accelerate/Prelude.hs @@ -121,15 +121,17 @@ module Data.Array.Accelerate.Prelude ( import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Language import Data.Array.Accelerate.Lift -import Data.Array.Accelerate.Pattern import Data.Array.Accelerate.Pattern.Maybe +import Data.Array.Accelerate.Pattern.Shape +import Data.Array.Accelerate.Pattern.Tuple import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Array ( Arrays, Array, Scalar, Vector, Segments, fromList ) import Data.Array.Accelerate.Sugar.Elt -import Data.Array.Accelerate.Sugar.Shape ( Shape, Slice, Z(..), (:.)(..), All(..), DIM1, DIM2, empty ) +import Data.Array.Accelerate.Sugar.Shape ( Shape, Slice, DIM1, DIM2, empty ) import Data.Array.Accelerate.Type import Data.Array.Accelerate.Classes.Eq +import Data.Array.Accelerate.Classes.FromBool import Data.Array.Accelerate.Classes.FromIntegral import Data.Array.Accelerate.Classes.Integral import Data.Array.Accelerate.Classes.Num @@ -138,10 +140,7 @@ import Data.Array.Accelerate.Classes.Ord import Data.Array.Accelerate.Data.Bits import Lens.Micro ( Lens', (&), (^.), (.~), (+~), (-~), lens, over ) -import Prelude ( (.), ($), Maybe(..), const, id, flip ) -#if __GLASGOW_HASKELL__ >= 904 -import Data.Type.Equality -#endif +import Prelude ( (.), ($), const, id, flip ) -- $setup @@ -708,7 +707,7 @@ fold1All f arr = fold1 f (flatten arr) -- 40, 170, 0, 138] -- foldSeg - :: forall sh e i. (Shape sh, Elt e, Elt i, i ~ EltR i, IsIntegral i) + :: forall sh e i. (Shape sh, Elt e, Num i, IsSingleIntegral (EltR i)) => (Exp e -> Exp e -> Exp e) -> Exp e -> Acc (Array (sh:.Int) e) @@ -717,17 +716,17 @@ foldSeg foldSeg f z arr seg = foldSeg' f z arr (scanl plus zero seg) where (plus, zero) = - case integralType @i of - TypeInt{} -> ((+), 0) - TypeInt8{} -> ((+), 0) - TypeInt16{} -> ((+), 0) - TypeInt32{} -> ((+), 0) - TypeInt64{} -> ((+), 0) - TypeWord{} -> ((+), 0) - TypeWord8{} -> ((+), 0) - TypeWord16{} -> ((+), 0) - TypeWord32{} -> ((+), 0) - TypeWord64{} -> ((+), 0) + case singleIntegralType @(EltR i) of + TypeInt8{} -> ((+), 0) + TypeInt16{} -> ((+), 0) + TypeInt32{} -> ((+), 0) + TypeInt64{} -> ((+), 0) + TypeInt128{} -> ((+), 0) + TypeWord8{} -> ((+), 0) + TypeWord16{} -> ((+), 0) + TypeWord32{} -> ((+), 0) + TypeWord64{} -> ((+), 0) + TypeWord128{} -> ((+), 0) -- | Variant of 'foldSeg' that requires /all/ segments of the reduced array @@ -735,7 +734,7 @@ foldSeg f z arr seg = foldSeg' f z arr (scanl plus zero seg) -- descriptor species the length of each of the logical sub-arrays. -- fold1Seg - :: forall sh e i. (Shape sh, Elt e, Elt i, i ~ EltR i, IsIntegral i) + :: forall sh e i. (Shape sh, Elt e, Num i, IsSingleIntegral (EltR i)) => (Exp e -> Exp e -> Exp e) -> Acc (Array (sh:.Int) e) -> Acc (Segments i) @@ -745,17 +744,17 @@ fold1Seg f arr seg = fold1Seg' f arr (scanl plus zero seg) plus :: Exp i -> Exp i -> Exp i zero :: Exp i (plus, zero) = - case integralType @(EltR i) of - TypeInt{} -> ((+), 0) - TypeInt8{} -> ((+), 0) - TypeInt16{} -> ((+), 0) - TypeInt32{} -> ((+), 0) - TypeInt64{} -> ((+), 0) - TypeWord{} -> ((+), 0) - TypeWord8{} -> ((+), 0) - TypeWord16{} -> ((+), 0) - TypeWord32{} -> ((+), 0) - TypeWord64{} -> ((+), 0) + case singleIntegralType @(EltR i) of + TypeInt8{} -> ((+), 0) + TypeInt16{} -> ((+), 0) + TypeInt32{} -> ((+), 0) + TypeInt64{} -> ((+), 0) + TypeInt128{} -> ((+), 0) + TypeWord8{} -> ((+), 0) + TypeWord16{} -> ((+), 0) + TypeWord32{} -> ((+), 0) + TypeWord64{} -> ((+), 0) + TypeWord128{} -> ((+), 0) -- Specialised reductions @@ -807,14 +806,14 @@ any f = or . map f and :: Shape sh => Acc (Array (sh:.Int) Bool) -> Acc (Array sh Bool) -and = fold (&&) True_ +and = fold (&&) True -- | Check if any element along the innermost dimension is 'True'. -- or :: Shape sh => Acc (Array (sh:.Int) Bool) -> Acc (Array sh Bool) -or = fold (||) False_ +or = fold (||) False -- | Compute the sum of elements along the innermost dimension of the array. To -- find the sum of the entire array, 'flatten' it first. @@ -985,7 +984,7 @@ scanlSeg -> Acc (Array (sh:.Int) e) scanlSeg f z arr seg = if null arr || null flags - then fill (sh ::. sz + length seg) z + then fill (sh :. sz + length seg) z else scanl1Seg f arr' seg' where -- Segmented exclusive scan is implemented by first injecting the seed @@ -996,11 +995,11 @@ scanlSeg f z arr seg = -- overlaying the input data in all places other than at the start of -- a segment. -- - sh ::. sz = shape arr + sh :. sz = shape arr seg' = map (+1) seg arr' = permute const - (fill (sh ::. sz + length seg) z) - (\(sx ::. i) -> Just_ (sx ::. i + fromIntegral (inc ! I1 i))) + (fill (sh :. sz + length seg) z) + (\(sx :. i) -> Just (sx :. i + fromIntegral (inc ! I1 i))) (take (length flags) arr) -- Each element in the segments must be shifted to the right one additional @@ -1058,7 +1057,7 @@ scanl'Seg -> Acc (Array (sh:.Int) e, Array (sh:.Int) e) scanl'Seg f z arr seg = if null arr - then T2 arr (fill (indexTail (shape arr) ::. length seg) z) + then T2 arr (fill (indexTail (shape arr) :. length seg) z) else T2 body sums where -- Segmented scan' is implemented by deconstructing a segmented exclusive @@ -1078,8 +1077,8 @@ scanl'Seg f z arr seg = seg' = map (+1) seg tails = zipWith (+) seg $ prescanl (+) 0 seg' sums = backpermute - (indexTail (shape arr') ::. length seg) - (\(sz ::. i) -> sz ::. fromIntegral (tails ! I1 i)) + (indexTail (shape arr') :. length seg) + (\(sz :. i) -> sz :. fromIntegral (tails ! I1 i)) arr' -- Slice out the body of each segment. @@ -1093,13 +1092,13 @@ scanl'Seg f z arr seg = offset = scanl1 (+) seg inc = scanl1 (+) $ permute (+) (fill (I1 $ size arr + 1) 0) - (\ix -> Just_ (index1' (offset ! ix))) + (\ix -> Just (index1' (offset ! ix))) (fill (shape seg) (1 :: Exp i)) len = offset ! I1 (length offset - 1) body = backpermute - (indexTail (shape arr) ::. fromIntegral len) - (\(sz ::. i) -> sz ::. i + fromIntegral (inc ! I1 i)) + (indexTail (shape arr) :. fromIntegral len) + (\(sz :. i) -> sz :. i + fromIntegral (inc ! I1 i)) arr' @@ -1136,7 +1135,7 @@ scanl'Seg f z arr seg = -- 40, 41, 83, 126, 170, 45, 91, 138] -- scanl1Seg - :: (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) + :: forall sh i e. (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) => (Exp e -> Exp e -> Exp e) -> Acc (Array (sh:.Int) e) -> Acc (Segments i) @@ -1144,7 +1143,7 @@ scanl1Seg scanl1Seg f arr seg = map snd . scanl1 (segmentedL f) - $ zip (replicate (lift (indexTail (shape arr) :. All)) (mkHeadFlags seg)) arr + $ zip (replicate @(sh :. All) (indexTail (shape arr) :. All) (mkHeadFlags seg)) arr -- |Segmented version of 'prescanl'. -- @@ -1206,10 +1205,10 @@ scanrSeg -> Acc (Array (sh:.Int) e) scanrSeg f z arr seg = if null arr || null flags - then fill (sh ::. sz + length seg) z + then fill (sh :. sz + length seg) z else scanr1Seg f arr' seg' where - sh ::. sz = shape arr + sh :. sz = shape arr -- Using technique described for 'scanlSeg', where we intersperse the array -- with the seed element at the start of each segment, and then perform an @@ -1220,8 +1219,8 @@ scanrSeg f z arr seg = seg' = map (+1) seg arr' = permute const - (fill (sh ::. sz + length seg) z) - (\(sx ::. i) -> Just_ (sx ::. i + fromIntegral (inc !! i) - 1)) + (fill (sh :. sz + length seg) z) + (\(sx :. i) -> Just (sx :. i + fromIntegral (inc !! i) - 1)) (drop (sz - length flags) arr) @@ -1265,7 +1264,7 @@ scanr'Seg -> Acc (Array (sh:.Int) e, Array (sh:.Int) e) scanr'Seg f z arr seg = if null arr - then T2 arr (fill (indexTail (shape arr) ::. length seg) z) + then T2 arr (fill (indexTail (shape arr) :. length seg) z) else T2 body sums where -- Using technique described for scanl'Seg @@ -1276,16 +1275,16 @@ scanr'Seg f z arr seg = seg' = map (+1) seg heads = prescanl (+) 0 seg' sums = backpermute - (indexTail (shape arr') ::. length seg) - (\(sz ::.i) -> sz ::. fromIntegral (heads ! I1 i)) + (indexTail (shape arr') :. length seg) + (\(sz :.i) -> sz :. fromIntegral (heads ! I1 i)) arr' -- body segments flags = mkHeadFlags seg inc = scanl1 (+) flags body = backpermute - (indexTail (shape arr) ::. indexHead (shape flags)) - (\(sz ::. i) -> sz ::. i + fromIntegral (inc ! I1 i)) + (indexTail (shape arr) :. indexHead (shape flags)) + (\(sz :. i) -> sz :. i + fromIntegral (inc ! I1 i)) arr' @@ -1313,7 +1312,7 @@ scanr'Seg f z arr seg = -- 40, 170, 129, 87, 44, 138, 93, 47] -- scanr1Seg - :: (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) + :: forall sh i e. (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int) => (Exp e -> Exp e -> Exp e) -> Acc (Array (sh:.Int) e) -> Acc (Segments i) @@ -1321,7 +1320,7 @@ scanr1Seg scanr1Seg f arr seg = map snd . scanr1 (segmentedR f) - $ zip (replicate (lift (indexTail (shape arr) :. All)) (mkTailFlags seg)) arr + $ zip (replicate @(sh :. All) (indexTail (shape arr) :. All) (mkTailFlags seg)) arr -- |Segmented version of 'prescanr'. @@ -1367,7 +1366,7 @@ mkHeadFlags -> Acc (Segments i) mkHeadFlags seg = init - $ permute (+) zeros (\ix -> Just_ (index1' (offset ! ix))) ones + $ permute (+) zeros (\ix -> Just (index1' (offset ! ix))) ones where T2 offset len = scanl' (+) 0 seg zeros = fill (index1' $ the len + 1) 0 @@ -1382,7 +1381,7 @@ mkTailFlags -> Acc (Segments i) mkTailFlags seg = init - $ permute (+) zeros (\ix -> Just_ (index1' (the len - 1 - offset ! ix))) ones + $ permute (+) zeros (\ix -> Just (index1' (the len - 1 - offset ! ix))) ones where T2 offset len = scanr' (+) 0 seg zeros = fill (index1' $ the len + 1) 0 @@ -1418,7 +1417,7 @@ segmentedR f y x = segmentedL (flip f) x y -- base library, because it results in too many ambiguity errors. -- index1' :: (Integral i, FromIntegral i Int) => Exp i -> Exp DIM1 -index1' i = lift (Z :. fromIntegral i) +index1' i = Z :. fromIntegral i -- Reshaping of arrays @@ -1658,15 +1657,15 @@ compact keep arr -- for the offset indices. | Just Refl <- matchShapeType @sh @Z = let - T2 target len = scanl' (+) 0 (map boolToInt keep) + T2 target len = scanl' (+) 0 (map fromBool keep) prj ix = if keep!ix - then Just_ (I1 (target!ix)) - else Nothing_ + then Just (I1 (target!ix)) + else Nothing dummy = fill (I1 (the len)) undef result = permute const dummy prj arr in if null arr - then T2 emptyArray (fill Z_ 0) + then T2 emptyArray (fill Z 0) else if the len == unindex1 (shape arr) then T2 arr len @@ -1675,11 +1674,11 @@ compact keep arr compact keep arr = let sz = indexTail (shape arr) - T2 target len = scanl' (+) 0 (map boolToInt keep) + T2 target len = scanl' (+) 0 (map fromBool keep) T2 offset valid = scanl' (+) 0 (flatten len) prj ix = if keep!ix - then Just_ (I1 (offset !! (toIndex sz (indexTail ix)) + target!ix)) - else Nothing_ + then Just (I1 (offset !! (toIndex sz (indexTail ix)) + target!ix)) + else Nothing dummy = fill (I1 (the valid)) undef result = permute const dummy prj arr in @@ -1760,7 +1759,7 @@ scatter -> Acc (Vector e) scatter to defaults input = permute const defaults pf input' where - pf ix = Just_ (I1 (to ! ix)) + pf ix = Just (I1 (to ! ix)) input' = backpermute (shape to `intersect` shape input) id input @@ -1789,8 +1788,8 @@ scatterIf to maskV pred defaults input = permute const defaults pf input' where input' = backpermute (shape to `intersect` shape input) id input pf ix = if pred (maskV ! ix) - then Just_ (I1 (to ! ix)) - else Nothing_ + then Just (I1 (to ! ix)) + else Nothing -- Permutations @@ -2226,9 +2225,9 @@ instance Arrays a => IfThenElse (Exp Bool) (Acc a) where -- argument. For example, given the function: -- -- > example1 :: Exp (Maybe Bool) -> Exp Int --- > example1 Nothing_ = 0 --- > example1 (Just_ False_) = 1 --- > example1 (Just_ True_) = 2 +-- > example1 Nothing = 0 +-- > example1 (Just False) = 1 +-- > example1 (Just True) = 2 -- -- In order to use this function it must be applied to the 'match' -- operator: @@ -2239,14 +2238,14 @@ instance Arrays a => IfThenElse (Exp Bool) (Acc a) where -- case statements inline. For example, instead of this: -- -- > example2 x = case f x of --- > Nothing_ -> ... -- error: embedded pattern synonym... --- > Just_ y -> ... -- ...used outside of 'match' context +-- > Nothing -> ... -- error: embedded pattern synonym... +-- > Just y -> ... -- ...used outside of 'match' context -- -- This can be written instead as: -- -- > example3 x = f x & match \case --- > Nothing_ -> ... --- > Just_ y -> ... +-- > Nothing -> ... +-- > Just y -> ... -- -- And utilising the @LambdaCase@ and @BlockArguments@ syntactic extensions. -- @@ -2263,27 +2262,25 @@ instance Arrays a => IfThenElse (Exp Bool) (Acc a) where -- -- > isNone :: Elt a => Exp (Option a) -> Exp Bool -- > isNone = match \case --- > None_ -> True_ --- > Some_{} -> False_ +-- > None_ -> True +-- > Some_{} -> False -- -- @since 1.3.0.0 -- match :: Matching f => f -> f match f = mkFun (mkMatch f) id -data Args f where - (:->) :: Exp a -> Args b -> Args (Exp a -> b) - Result :: Args (Exp a) - class Matching a where type ResultT a - mkMatch :: a -> Args a -> Exp (ResultT a) - mkFun :: (Args f -> Exp (ResultT a)) - -> (Args a -> Args f) + data ArgsR a + mkMatch :: a -> ArgsR a -> Exp (ResultT a) + mkFun :: (ArgsR f -> Exp (ResultT a)) + -> (ArgsR a -> ArgsR f) -> a instance Elt a => Matching (Exp a) where type ResultT (Exp a) = a + data ArgsR (Exp a) = Result mkFun f k = f (k Result) mkMatch (Exp e) Result = @@ -2293,6 +2290,7 @@ instance Elt a => Matching (Exp a) where instance (Elt e, Matching r) => Matching (Exp e -> r) where type ResultT (Exp e -> r) = ResultT r + data ArgsR (Exp e -> r) = Exp e :-> ArgsR r mkFun f k x = mkFun f (\xs -> k (x :-> xs)) mkMatch f (x@(Exp p) :-> xs) = @@ -2342,7 +2340,7 @@ sfoldl :: (Shape sh, Elt a, Elt b) -> Exp a sfoldl f z ix xs = let n = indexHead (shape xs) - step (T2 i acc) = T2 (i+1) (acc `f` (xs ! (ix ::. i))) + step (T2 i acc) = T2 (i+1) (acc `f` (xs ! (ix :. i))) in snd $ while (\v -> fst v < n) step (T2 0 z) @@ -2385,17 +2383,17 @@ uncurry f t = let (x, y) = unlift t in f x y -- | The one index for a rank-0 array. -- index0 :: Exp Z -index0 = lift Z +index0 = Z -- | Turn an 'Int' expression into a rank-1 indexing expression. -- index1 :: Elt i => Exp i -> Exp (Z :. i) -index1 i = lift (Z :. i) +index1 i = Z :. i -- | Turn a rank-1 indexing expression into an 'Int' expression. -- unindex1 :: Elt i => Exp (Z :. i) -> Exp i -unindex1 ix = let Z :. i = unlift ix in i +unindex1 (Z :. i) = i -- | Creates a rank-2 index from two Exp Int`s -- @@ -2404,7 +2402,7 @@ index2 => Exp i -> Exp i -> Exp (Z :. i :. i) -index2 i j = lift (Z :. i :. j) +index2 i j = Z :. i :. j -- | Destructs a rank-2 index to an Exp tuple of two Int`s. -- @@ -2412,7 +2410,7 @@ unindex2 :: Elt i => Exp (Z :. i :. i) -> Exp (i, i) -unindex2 (Z_ ::. i ::. j) = T2 i j +unindex2 (Z :. i :. j) = T2 i j -- | Create a rank-3 index from three Exp Int`s -- @@ -2422,14 +2420,14 @@ index3 -> Exp i -> Exp i -> Exp (Z :. i :. i :. i) -index3 k j i = Z_ ::. k ::. j ::. i +index3 k j i = Z :. k :. j :. i -- | Destruct a rank-3 index into an Exp tuple of Int`s unindex3 :: Elt i => Exp (Z :. i :. i :. i) -> Exp (i, i, i) -unindex3 (Z_ ::. k ::. j ::. i) = T3 k j i +unindex3 (Z :. k :. j :. i) = T3 k j i -- Array operations with a scalar result @@ -2497,7 +2495,7 @@ length = unindex1 . shape -- new = -- let m = c2-c1 -- put i = let s = sieves ! i --- in s >= 0 && s < m ? (Just_ (I1 s), Nothing_) +-- in s >= 0 && s < m ? (Just (I1 s), Nothing) -- in -- afst -- $ filter (> 0) @@ -2532,7 +2530,7 @@ expand f g xs = else let n = m + 1 - put ix = Just_ (I1 (offset ! ix)) + put ix = Just (I1 (offset ! ix)) head_flags :: Acc (Vector Int) head_flags = permute const (fill (I1 n) 0) put (fill (shape szs) 1) @@ -2554,7 +2552,7 @@ expand f g xs = -- also the same, which is undefined behaviour (\ix -> if szs ! ix > 0 then put ix - else Nothing_) + else Nothing) $ enumFromN (shape xs) 0 in zipWith g (gather iotas xs) idxs @@ -2618,14 +2616,14 @@ emptyArray = fill (constant empty) undef -- Imported from `lens-accelerate` (which provides more general Field instances) -- _1 :: forall sh. Elt sh => Lens' (Exp (sh:.Int)) (Exp Int) -_1 = lens (\ix -> let _ :. x = unlift ix :: Exp sh :. Exp Int in x) - (\ix x -> let sh :. _ = unlift ix :: Exp sh :. Exp Int in lift (sh :. x)) +_1 = lens (\ix -> let _ :. x = ix in x) + (\ix x -> let sh :. _ = ix in sh :. x) _2 :: forall sh. Elt sh => Lens' (Exp (sh:.Int:.Int)) (Exp Int) -_2 = lens (\ix -> let _ :. y :. _ = unlift ix :: Exp sh :. Exp Int :. Exp Int in y) - (\ix y -> let sh :. _ :. x = unlift ix :: Exp sh :. Exp Int :. Exp Int in lift (sh :. y :. x)) +_2 = lens (\ix -> let _ :. y :. _ = ix in y) + (\ix y -> let sh :. _ :. x = ix in sh :. y :. x) _3 :: forall sh. Elt sh => Lens' (Exp (sh:.Int:.Int:.Int)) (Exp Int) -_3 = lens (\ix -> let _ :. z :. _ :. _ = unlift ix :: Exp sh :. Exp Int :. Exp Int :. Exp Int in z) - (\ix z -> let sh :. _ :. y :. x = unlift ix :: Exp sh :. Exp Int :. Exp Int :. Exp Int in lift (sh :. z :. y :. x)) +_3 = lens (\ix -> let _ :. z :. _ :. _ = ix in z) + (\ix z -> let sh :. _ :. y :. x = ix in sh :. z :. y :. x) diff --git a/src/Data/Array/Accelerate/Pretty/Graphviz.hs b/src/Data/Array/Accelerate/Pretty/Graphviz.hs index ca63fd323..04547e616 100644 --- a/src/Data/Array/Accelerate/Pretty/Graphviz.hs +++ b/src/Data/Array/Accelerate/Pretty/Graphviz.hs @@ -192,9 +192,9 @@ prettyDelayedOpenAcc detail ctx aenv (Manifest pacc) = Avar ix -> pnode (avar ix) Alet lhs bnd body -> do bnd'@(PNode ident _ _) <- prettyDelayedOpenAcc detail context0 aenv bnd - (aenv1, a) <- prettyLetALeftHandSide ident aenv lhs - _ <- mkNode bnd' (Just a) - body' <- prettyDelayedOpenAcc detail context0 aenv1 body + (aenv1, a) <- prettyLetALeftHandSide ident aenv lhs + _ <- mkNode bnd' (Just a) + body' <- prettyDelayedOpenAcc detail context0 aenv1 body return body' Acond p t e -> do @@ -216,14 +216,16 @@ prettyDelayedOpenAcc detail ctx aenv (Manifest pacc) = p' <- prettyDelayedAfun detail aenv p f' <- prettyDelayedAfun detail aenv f -- - let PNode _ (Leaf (Nothing,xb)) fvs = x' - loop = nest 2 (sep ["awhile", pretty p', pretty f', xb ]) - return $ PNode ident (Leaf (Nothing,loop)) fvs + case x' of + PNode _ (Leaf (Nothing,xb)) fvs -> let loop = nest 2 (sep ["awhile", pretty p', pretty f', xb ]) + in return $ PNode ident (Leaf (Nothing,loop)) fvs + _ -> internalError "unexpected node" Apair a1 a2 -> genNodeId >>= prettyDelayedApair detail aenv a1 a2 Anil -> "()" .$ [] Atrace (Message _ _ msg) as bs -> "atrace" .$ [ return $ PDoc (pretty msg) [], ppA as, ppA bs ] + Acoerce _ bR a -> "coerce" .$ [ return $ PDoc ("@" <> pretty (show bR)) [], ppA a ] Use repr arr -> "use" .$ [ return $ PDoc (prettyArray repr arr) [] ] Unit _ e -> "unit" .$ [ ppE e ] Generate _ sh f -> "generate" .$ [ ppE sh, ppF f ] @@ -519,20 +521,21 @@ fvOpenExp env aenv = fv fv Evar{} = [] fv Undef{} = [] fv Const{} = [] - fv PrimConst{} = [] fv (PrimApp _ x) = fv x - fv (Pair e1 e2) = concat [ fv e1, fv e2] + fv (Pair e1 e2) = concat [ fv e1, fv e2 ] fv Nil = [] - fv (VecPack _ e) = fv e - fv (VecUnpack _ e) = fv e + fv (Extract _ _ v i) = concat [ fv v, fv i ] + fv (Insert _ _ v i x) = concat [ fv v, fv i, fv x ] + fv (Shuffle _ _ x y i) = concat [ fv x, fv y, fv i ] + fv (Select _ m x y) = concat [ fv m, fv x, fv y ] fv (IndexSlice _ slix sh) = concat [ fv slix, fv sh ] fv (IndexFull _ slix sh) = concat [ fv slix, fv sh ] fv (ToIndex _ sh ix) = concat [ fv sh, fv ix ] fv (FromIndex _ sh ix) = concat [ fv sh, fv ix ] fv (ShapeSize _ sh) = fv sh fv Foreign{} = [] - fv (Case e rhs def) = concat [ fv e, concat [ fv c | (_,c) <- rhs ], maybe [] fv def ] + fv (Case _ e rhs def) = concat [ fv e, concat [ fv c | (_,c) <- rhs ], maybe [] fv def ] fv (Cond p t e) = concat [ fv p, fv t, fv e ] fv (While p f x) = concat [ fvF p, fvF f, fv x ] - fv (Coerce _ _ e) = fv e + fv (Bitcast _ _ e) = fv e diff --git a/src/Data/Array/Accelerate/Pretty/Print.hs b/src/Data/Array/Accelerate/Pretty/Print.hs index 965aecd99..34b7dd93a 100644 --- a/src/Data/Array/Accelerate/Pretty/Print.hs +++ b/src/Data/Array/Accelerate/Pretty/Print.hs @@ -59,6 +59,7 @@ import Data.Array.Accelerate.AST hiding ( Dir import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.AST.LeftHandSide import Data.Array.Accelerate.AST.Var +import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Elt import Data.Array.Accelerate.Representation.Stencil @@ -194,6 +195,7 @@ prettyPreOpenAcc config ctx prettyAcc extractAcc aenv pacc = Atrace (Message _ _ msg) as bs -> ppN "atrace" .$ [ fromString (show msg), ppA as, ppA bs ] + Acoerce _ bR a -> ppN "coerce" .$ [ "@" <> pretty (show bR), ppA a ] Aforeign _ ff _ a -> ppN "aforeign" .$ [ pretty (strForeign ff), ppA a ] Awhile p f a -> ppN "awhile" .$ [ ppAF p, ppAF f, ppA a ] Use repr arr -> ppN "use" .$ [ prettyArray repr arr ] @@ -412,13 +414,14 @@ prettyOpenExp ctx env aenv exp = op = primOperator f op' = isInfix op ? (Operator (parens (opName op)) App L 10, op) -- - PrimConst c -> prettyPrimConst c Const tp c -> prettyConst (TupRsingle tp) c Pair{} -> prettyTuple ctx env aenv exp Nil -> "()" - VecPack _ e -> ppF1 "pack" (ppE e) - VecUnpack _ e -> ppF1 "unpack" (ppE e) - Case x xs d -> prettyCase env aenv x xs d + Extract _ _ v i -> ppF2 (Operator "#" Infix L 9) (ppE v) (ppE i) + Insert{} -> prettyInsert ctx env aenv exp + Shuffle _ _ x y i -> ppF3 "shuffle" (ppE x) (ppE y) (ppE i) + Select _ m x y -> ppF3 "select" (ppE m) (ppE x) (ppE y) + Case tR x xs d -> prettyCase env aenv tR x xs d Cond p t e -> flatAlt multi single where p' = ppE p context0 @@ -442,7 +445,7 @@ prettyOpenExp ctx env aenv exp = ShapeSize _ sh -> ppF1 "shapeSize" (ppE sh) Index arr ix -> ppF2 (Operator (pretty '!') Infix L 9) (ppA arr) (ppE ix) LinearIndex arr ix -> ppF2 (Operator "!!" Infix L 9) (ppA arr) (ppE ix) - Coerce _ tp x -> ppF1 (Operator (withTypeRep tp "coerce") App L 10) (ppE x) + Bitcast _ tp x -> ppF1 (Operator (withTypeRep tp "bitcast") App L 10) (ppE x) Undef tp -> withTypeRep tp "undef" where @@ -558,24 +561,62 @@ prettyTuple ctx env aenv exp = case collect exp of prettyCase :: Val env -> Val aenv + -> TagType tag -> OpenExp env aenv a - -> [(TAG, OpenExp env aenv b)] + -> [(tag, OpenExp env aenv b)] -> Maybe (OpenExp env aenv b) -> Adoc -prettyCase env aenv x xs def +prettyCase env aenv tagR x xs def = hang shiftwidth $ vsep [ case_ <+> x' <+> of_ , flatAlt (vcat xs') (encloseSep "{ " " }" "; " xs') ] where x' = prettyOpenExp context0 env aenv x - xs' = map (\(t,e) -> pretty t <+> "->" <+> prettyOpenExp context0 env aenv e) xs + tR = case tagR of + TagBit -> scalarType + TagWord8 -> scalarType + TagWord16 -> scalarType + xs' = map (\(t,e) -> prettyConst (TupRsingle tR) t <+> "->" <+> prettyOpenExp context0 env aenv e) xs ++ case def of Nothing -> [] Just d -> ["_" <+> "->" <+> prettyOpenExp context0 env aenv d] -{- - +prettyInsert + :: forall env aenv t. + Context + -> Val env + -> Val aenv + -> OpenExp env aenv t + -> Adoc +prettyInsert ctx env aenv exp = + case collect exp of + Just (c, xs) -> align $ parensIf (ctxPrecedence ctx > 0) ("V" <> pretty c <+> align (sep (reverse xs))) + Nothing -> align $ ppInsert exp + where + ppInsert :: OpenExp env aenv t' -> Adoc + ppInsert (Insert _ _ v i x) = + let v' = prettyOpenExp ctx env aenv v + i' = prettyOpenExp ctx env aenv i + x' = prettyOpenExp ctx env aenv x + in + parensIf (ctxPrecedence ctx > 0) + $ hang 2 + $ sep [ "insert", brackets i', x', v'] + ppInsert e = + prettyOpenExp context0 env aenv e + + collect :: OpenExp env aenv t' -> Maybe (Word8, [Adoc]) + collect Undef{} + = Just (0, []) + collect (Insert _ _ v i x) + | Just (i', xs) <- collect v + , Just Refl <- matchOpenExp i (Const scalarType i') + = Just (i'+1, prettyOpenExp app env aenv x : xs) + collect _ + = Nothing + +{-- prettyAtuple :: forall acc aenv arrs. PrettyAcc acc @@ -597,17 +638,16 @@ prettyAtuple prettyAcc extractAcc aenv0 acc = case collect acc of | Just tup <- collect $ extractAcc a1 = Just $ tup ++ [prettyAcc app aenv0 a2] collect _ = Nothing --} +--} prettyConst :: TypeR e -> e -> Adoc prettyConst tp x = - let y = showElt tp x - in parensIf (any isSpace y) (pretty y) - -prettyPrimConst :: PrimConst a -> Adoc -prettyPrimConst PrimMinBound{} = "minBound" -prettyPrimConst PrimMaxBound{} = "maxBound" -prettyPrimConst PrimPi{} = "pi" + let y = showElt tp x + -- + isVec [] = False + isVec xs = head xs == '<' && last xs == '>' + in + parensIf (any isSpace y && not (isVec y)) (pretty y) -- Primitive operators @@ -665,68 +705,82 @@ isInfix :: Operator -> Bool isInfix Operator{..} = opFixity == Infix primOperator :: PrimFun a -> Operator -primOperator PrimAdd{} = Operator (pretty '+') Infix L 6 -primOperator PrimSub{} = Operator (pretty '-') Infix L 6 -primOperator PrimMul{} = Operator (pretty '*') Infix L 7 -primOperator PrimNeg{} = Operator (pretty '-') Prefix L 6 -- Haskell's only prefix operator -primOperator PrimAbs{} = Operator "abs" App L 10 -primOperator PrimSig{} = Operator "signum" App L 10 -primOperator PrimQuot{} = Operator "quot" App L 10 -primOperator PrimRem{} = Operator "rem" App L 10 -primOperator PrimQuotRem{} = Operator "quotRem" App L 10 -primOperator PrimIDiv{} = Operator "div" App L 10 -primOperator PrimMod{} = Operator "mod" App L 10 -primOperator PrimDivMod{} = Operator "divMod" App L 10 -primOperator PrimBAnd{} = Operator ".&." Infix L 7 -primOperator PrimBOr{} = Operator ".|." Infix L 5 -primOperator PrimBXor{} = Operator "xor" App L 10 -primOperator PrimBNot{} = Operator "complement" App L 10 -primOperator PrimBShiftL{} = Operator "shiftL" App L 10 -primOperator PrimBShiftR{} = Operator "shiftR" App L 10 -primOperator PrimBRotateL{} = Operator "rotateL" App L 10 -primOperator PrimBRotateR{} = Operator "rotateR" App L 10 -primOperator PrimPopCount{} = Operator "popCount" App L 10 -primOperator PrimCountLeadingZeros{} = Operator "countLeadingZeros" App L 10 -primOperator PrimCountTrailingZeros{} = Operator "countTrailingZeros" App L 10 -primOperator PrimFDiv{} = Operator (pretty '/') Infix L 7 -primOperator PrimRecip{} = Operator "recip" App L 10 -primOperator PrimSin{} = Operator "sin" App L 10 -primOperator PrimCos{} = Operator "cos" App L 10 -primOperator PrimTan{} = Operator "tan" App L 10 -primOperator PrimAsin{} = Operator "asin" App L 10 -primOperator PrimAcos{} = Operator "acos" App L 10 -primOperator PrimAtan{} = Operator "atan" App L 10 -primOperator PrimSinh{} = Operator "sinh" App L 10 -primOperator PrimCosh{} = Operator "cosh" App L 10 -primOperator PrimTanh{} = Operator "tanh" App L 10 -primOperator PrimAsinh{} = Operator "asinh" App L 10 -primOperator PrimAcosh{} = Operator "acosh" App L 10 -primOperator PrimAtanh{} = Operator "atanh" App L 10 -primOperator PrimExpFloating{} = Operator "exp" App L 10 -primOperator PrimSqrt{} = Operator "sqrt" App L 10 -primOperator PrimLog{} = Operator "log" App L 10 -primOperator PrimFPow{} = Operator "**" Infix R 8 -primOperator PrimLogBase{} = Operator "logBase" App L 10 -primOperator PrimTruncate{} = Operator "truncate" App L 10 -primOperator PrimRound{} = Operator "round" App L 10 -primOperator PrimFloor{} = Operator "floor" App L 10 -primOperator PrimCeiling{} = Operator "ceiling" App L 10 -primOperator PrimAtan2{} = Operator "atan2" App L 10 -primOperator PrimIsNaN{} = Operator "isNaN" App L 10 -primOperator PrimIsInfinite{} = Operator "isInfinite" App L 10 -primOperator PrimLt{} = Operator "<" Infix N 4 -primOperator PrimGt{} = Operator ">" Infix N 4 -primOperator PrimLtEq{} = Operator "<=" Infix N 4 -primOperator PrimGtEq{} = Operator ">=" Infix N 4 -primOperator PrimEq{} = Operator "==" Infix N 4 -primOperator PrimNEq{} = Operator "/=" Infix N 4 -primOperator PrimMax{} = Operator "max" App L 10 -primOperator PrimMin{} = Operator "min" App L 10 -primOperator PrimLAnd = Operator "&&" Infix R 3 -primOperator PrimLOr = Operator "||" Infix R 2 -primOperator PrimLNot = Operator "not" App L 10 -primOperator PrimFromIntegral{} = Operator "fromIntegral" App L 10 -primOperator PrimToFloating{} = Operator "toFloating" App L 10 +primOperator = \case + PrimAdd{} -> Operator (pretty '+') Infix L 6 + PrimSub{} -> Operator (pretty '-') Infix L 6 + PrimMul{} -> Operator (pretty '*') Infix L 7 + PrimNeg{} -> Operator (pretty '-') Prefix L 6 -- Haskell's only prefix operator + PrimAbs{} -> Operator "abs" App L 10 + PrimSig{} -> Operator "signum" App L 10 + PrimVAdd{} -> Operator "vadd" App L 10 + PrimVMul{} -> Operator "vmul" App L 10 + PrimQuot{} -> Operator "quot" App L 10 + PrimRem{} -> Operator "rem" App L 10 + PrimQuotRem{} -> Operator "quotRem" App L 10 + PrimIDiv{} -> Operator "div" App L 10 + PrimMod{} -> Operator "mod" App L 10 + PrimDivMod{} -> Operator "divMod" App L 10 + PrimBAnd{} -> Operator ".&." Infix L 7 + PrimBOr{} -> Operator ".|." Infix L 5 + PrimBXor{} -> Operator "xor" App L 10 + PrimBNot{} -> Operator "complement" App L 10 + PrimBShiftL{} -> Operator "shiftL" App L 10 + PrimBShiftR{} -> Operator "shiftR" App L 10 + PrimBRotateL{} -> Operator "rotateL" App L 10 + PrimBRotateR{} -> Operator "rotateR" App L 10 + PrimPopCount{} -> Operator "popCount" App L 10 + PrimCountLeadingZeros{} -> Operator "countLeadingZeros" App L 10 + PrimCountTrailingZeros{} -> Operator "countTrailingZeros" App L 10 + PrimBReverse{} -> Operator "bitreverse" App L 10 + PrimBSwap{} -> Operator "byteswap" App L 10 + PrimVBAnd{} -> Operator "vand" App L 10 + PrimVBOr{} -> Operator "vor" App L 10 + PrimVBXor{} -> Operator "vxor" App L 10 + PrimFDiv{} -> Operator (pretty '/') Infix L 7 + PrimRecip{} -> Operator "recip" App L 10 + PrimSin{} -> Operator "sin" App L 10 + PrimCos{} -> Operator "cos" App L 10 + PrimTan{} -> Operator "tan" App L 10 + PrimAsin{} -> Operator "asin" App L 10 + PrimAcos{} -> Operator "acos" App L 10 + PrimAtan{} -> Operator "atan" App L 10 + PrimSinh{} -> Operator "sinh" App L 10 + PrimCosh{} -> Operator "cosh" App L 10 + PrimTanh{} -> Operator "tanh" App L 10 + PrimAsinh{} -> Operator "asinh" App L 10 + PrimAcosh{} -> Operator "acosh" App L 10 + PrimAtanh{} -> Operator "atanh" App L 10 + PrimExpFloating{} -> Operator "exp" App L 10 + PrimSqrt{} -> Operator "sqrt" App L 10 + PrimLog{} -> Operator "log" App L 10 + PrimFPow{} -> Operator "**" Infix R 8 + PrimLogBase{} -> Operator "logBase" App L 10 + PrimTruncate{} -> Operator "truncate" App L 10 + PrimRound{} -> Operator "round" App L 10 + PrimFloor{} -> Operator "floor" App L 10 + PrimCeiling{} -> Operator "ceiling" App L 10 + PrimAtan2{} -> Operator "atan2" App L 10 + PrimIsNaN{} -> Operator "isNaN" App L 10 + PrimIsInfinite{} -> Operator "isInfinite" App L 10 + PrimLt{} -> Operator "<" Infix N 4 + PrimGt{} -> Operator ">" Infix N 4 + PrimLtEq{} -> Operator "<=" Infix N 4 + PrimGtEq{} -> Operator ">=" Infix N 4 + PrimEq{} -> Operator "==" Infix N 4 + PrimNEq{} -> Operator "/=" Infix N 4 + PrimMin{} -> Operator "min" App L 10 + PrimMax{} -> Operator "max" App L 10 + PrimVMin{} -> Operator "minimum" App L 10 + PrimVMax{} -> Operator "maximum" App L 10 + PrimLAnd{} -> Operator "&&" Infix R 3 + PrimLOr{} -> Operator "||" Infix R 2 + PrimLNot{} -> Operator "not" App L 10 + PrimVLAnd{} -> Operator "and" App L 10 + PrimVLOr{} -> Operator "or" App L 10 + PrimFromIntegral{} -> Operator "fromIntegral" App L 10 + PrimToFloating{} -> Operator "toFloating" App L 10 + PrimToBool{} -> Operator "toBool" App L 10 + PrimFromBool{} -> Operator "fromBool" App L 10 -- Environments diff --git a/src/Data/Array/Accelerate/Representation/Array.hs b/src/Data/Array/Accelerate/Representation/Array.hs index ee02c0ad5..e36499589 100644 --- a/src/Data/Array/Accelerate/Representation/Array.hs +++ b/src/Data/Array/Accelerate/Representation/Array.hs @@ -93,7 +93,7 @@ arraysRpair a b = TupRunit `TupRpair` TupRsingle a `TupRpair` TupRsingle b -- allocateArray :: ArrayR (Array sh e) -> sh -> IO (Array sh e) allocateArray (ArrayR shR eR) sh = do - adata <- newArrayData eR (size shR sh) + adata <- newArrayData eR (fromIntegral (size shR sh)) return $! Array sh adata -- | Create an array from its representation function, applied at each @@ -109,13 +109,13 @@ fromFunction repr sh f = unsafePerformIO $! fromFunctionM repr sh (return . f) fromFunctionM :: ArrayR (Array sh e) -> sh -> (sh -> IO e) -> IO (Array sh e) fromFunctionM (ArrayR shR eR) sh f = do let !n = size shR sh - arr <- newArrayData eR n + arr <- newArrayData eR (fromIntegral n) -- let write !i | i >= n = return () | otherwise = do v <- f (fromIndex shR sh i) - writeArrayData eR arr i v + writeArrayData eR arr (fromIntegral i) v write (i+1) -- write 0 @@ -130,7 +130,7 @@ fromList (ArrayR shR eR) sh xs = adata `seq` Array sh adata -- Assume the array is in dense row-major order. This is safe because -- otherwise backends would not be able to directly memcpy. -- - !n = size shR sh + !n = fromIntegral (size shR sh) (adata, _) = runArrayData @e $ do arr <- newArrayData eR n let go !i _ | i >= n = return () @@ -149,21 +149,21 @@ toList (ArrayR shR eR) (Array sh adata) = go 0 -- Assume underling array is in row-major order. This is safe because -- otherwise backends would not be able to directly memcpy. -- - !n = size shR sh + !n = fromIntegral (size shR sh) go !i | i >= n = [] | otherwise = indexArrayData eR adata i : go (i+1) -concatVectors :: forall e. TypeR e -> [Vector e] -> Vector e -concatVectors tR vs = adata `seq` Array ((), len) adata - where - offsets = scanl (+) 0 (map (size dim1 . shape) vs) - len = last offsets - (adata, _) = runArrayData @e $ do - arr <- newArrayData tR len - sequence_ [ writeArrayData tR arr (i + k) (indexArrayData tR ad i) - | (Array ((), n) ad, k) <- vs `zip` offsets - , i <- [0 .. n - 1] ] - return (arr, undefined) +-- concatVectors :: forall e. TypeR e -> [Vector e] -> Vector e +-- concatVectors tR vs = adata `seq` Array ((), len) adata +-- where +-- offsets = scanl (+) 0 (map (size dim1 . shape) vs) +-- len = last offsets +-- (adata, _) = runArrayData @e $ do +-- arr <- newArrayData tR len +-- sequence_ [ writeArrayData tR arr (i + k) (indexArrayData tR ad i) +-- | (Array ((), n) ad, k) <- vs `zip` offsets +-- , i <- [0 .. n - 1] ] +-- return (arr, undefined) shape :: Array sh e -> sh shape (Array sh _) = sh @@ -176,14 +176,14 @@ reshape shR sh shR' (Array sh' adata) (!) :: (ArrayR (Array sh e), Array sh e) -> sh -> e (!) = uncurry indexArray -(!!) :: (TypeR e, Array sh e) -> Int -> e +(!!) :: (TypeR e, Array sh e) -> INT -> e (!!) = uncurry linearIndexArray indexArray :: ArrayR (Array sh e) -> Array sh e -> sh -> e -indexArray (ArrayR shR adR) (Array sh adata) ix = indexArrayData adR adata (toIndex shR sh ix) +indexArray (ArrayR shR adR) (Array sh adata) ix = indexArrayData adR adata (fromIntegral (toIndex shR sh ix)) -linearIndexArray :: TypeR e -> Array sh e -> Int -> e -linearIndexArray adR (Array _ adata) = indexArrayData adR adata +linearIndexArray :: TypeR e -> Array sh e -> INT -> e +linearIndexArray adR (Array _ adata) = indexArrayData adR adata . fromIntegral showArray :: (e -> ShowS) -> ArrayR (Array sh e) -> Array sh e -> String showArray f arrR@(ArrayR shR _) arr@(Array sh _) = case shR of @@ -212,9 +212,11 @@ showMatrix f (ArrayR _ arrR) arr@(Array sh _) | rows * cols == 0 = "[]" | otherwise = "\n [" ++ ppMat 0 0 where - (((), rows), cols) = sh - lengths = U.generate (rows*cols) (\i -> length (f (linearIndexArray arrR arr i) "")) - widths = U.generate cols (\c -> U.maximum (U.generate rows (\r -> lengths U.! (r*cols+c)))) + (((), ri), ci) = sh + rows = fromIntegral ri + cols = fromIntegral ci + lengths = U.generate (rows*cols) (\i -> length (f (linearIndexArray arrR arr (fromIntegral i)) "")) + widths = U.generate cols (\c -> U.maximum (U.generate rows (\r -> lengths U.! (r*cols+c)))) -- ppMat :: Int -> Int -> String ppMat !r !c | c >= cols = ppMat (r+1) 0 @@ -224,7 +226,7 @@ showMatrix f (ArrayR _ arrR) arr@(Array sh _) !l = lengths U.! i !w = widths U.! c !pad = 1 - cell = replicate (w-l+pad) ' ' ++ f (linearIndexArray arrR arr i) "" + cell = replicate (w-l+pad) ' ' ++ f (linearIndexArray arrR arr (fromIntegral i)) "" -- before | r > 0 && c == 0 = "\n " @@ -289,7 +291,7 @@ showsArrays repr arrs = go 0 repr arrs needsParens repr'@(TupRpair _ _) as = isJust $ extractTuple repr' as needsParens _ _ = True -reduceRank :: ArrayR (Array (sh, Int) e) -> ArrayR (Array sh e) +reduceRank :: ArrayR (Array (sh, INT) e) -> ArrayR (Array sh e) reduceRank (ArrayR (ShapeRsnoc shR) aeR) = ArrayR shR aeR rnfArray :: ArrayR a -> a -> () @@ -315,8 +317,7 @@ liftArray :: forall sh e. ArrayR (Array sh e) -> Array sh e -> CodeQ (Array sh e liftArray (ArrayR shR adR) (Array sh adata) = [|| Array $$(liftElt (shapeType shR) sh) $$(liftArrayData sz adR adata) ||] `at` [t| Array $(liftTypeQ (shapeType shR)) $(liftTypeQ adR) |] where - sz :: Int - sz = size shR sh + sz = fromIntegral (size shR sh) at :: CodeQ t -> Q Type -> CodeQ t at e t = unsafeCodeCoerce $ sigE (unTypeCode e) t diff --git a/src/Data/Array/Accelerate/Representation/Elt.hs b/src/Data/Array/Accelerate/Representation/Elt.hs index d72cd7165..af2ab2d64 100644 --- a/src/Data/Array/Accelerate/Representation/Elt.hs +++ b/src/Data/Array/Accelerate/Representation/Elt.hs @@ -1,4 +1,5 @@ {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} @@ -18,14 +19,16 @@ module Data.Array.Accelerate.Representation.Elt import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Type +import Data.Primitive.Bit import Data.Primitive.Vec import Control.Monad.ST -import Data.List ( intercalate ) import Data.Primitive.ByteArray -import Foreign.Storable import Language.Haskell.TH.Extra +import GHC.TypeLits +import GHC.Base + undefElt :: TypeR t -> t undefElt = tuple @@ -36,77 +39,72 @@ undefElt = tuple tuple (TupRsingle t) = scalar t scalar :: ScalarType t -> t - scalar (SingleScalarType t) = single t - scalar (VectorScalarType t) = vector t - - vector :: VectorType t -> t - vector (VectorType n t) = runST $ do - mba <- newByteArray (n * bytesElt (TupRsingle (SingleScalarType t))) - ByteArray ba# <- unsafeFreezeByteArray mba - return (Vec ba#) - - single :: SingleType t -> t - single (NumSingleType t) = num t + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + bit :: BitType t -> t + bit TypeBit = Bit False + bit (TypeMask n) = + let bytes = quot (fromInteger (natVal' n) + 7) 8 + in runST $ do + mba <- newByteArray bytes + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# num :: NumType t -> t num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t integral :: IntegralType t -> t - integral TypeInt = 0 - integral TypeInt8 = 0 - integral TypeInt16 = 0 - integral TypeInt32 = 0 - integral TypeInt64 = 0 - integral TypeWord = 0 - integral TypeWord8 = 0 - integral TypeWord16 = 0 - integral TypeWord32 = 0 - integral TypeWord64 = 0 + integral = \case + SingleIntegralType t -> single t + VectorIntegralType n t -> vector n t + where + single :: SingleIntegralType t -> t + single TypeInt8 = 0 + single TypeInt16 = 0 + single TypeInt32 = 0 + single TypeInt64 = 0 + single TypeInt128 = 0 + single TypeWord8 = 0 + single TypeWord16 = 0 + single TypeWord32 = 0 + single TypeWord64 = 0 + single TypeWord128 = 0 + + vector :: KnownNat n => Proxy# n -> SingleIntegralType t -> Vec n t + vector n t = runST $ do + let bits = case t of + TypeInt w -> w + TypeWord w -> w + bytes = max 1 (quot (fromInteger (natVal' n) * bits) 8) + mba <- newAlignedPinnedByteArray bytes 16 + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# floating :: FloatingType t -> t - floating TypeHalf = 0 - floating TypeFloat = 0 - floating TypeDouble = 0 - -bytesElt :: TypeR e -> Int -bytesElt = tuple - where - tuple :: TypeR t -> Int - tuple TupRunit = 0 - tuple (TupRpair ta tb) = tuple ta + tuple tb - tuple (TupRsingle t) = scalar t - - scalar :: ScalarType t -> Int - scalar (SingleScalarType t) = single t - scalar (VectorScalarType t) = vector t - - vector :: VectorType t -> Int - vector (VectorType n t) = n * single t - - single :: SingleType t -> Int - single (NumSingleType t) = num t - - num :: NumType t -> Int - num (IntegralNumType t) = integral t - num (FloatingNumType t) = floating t + floating = \case + SingleFloatingType t -> single t + VectorFloatingType n t -> vector n t + where + single :: SingleFloatingType t -> t + single TypeFloat16 = 0 + single TypeFloat32 = 0 + single TypeFloat64 = 0 + single TypeFloat128 = 0 + + vector :: KnownNat n => Proxy# n -> SingleFloatingType t -> Vec n t + vector n t = runST $ do + let bits = case t of + TypeFloat16 -> 16 + TypeFloat32 -> 32 + TypeFloat64 -> 64 + TypeFloat128 -> 128 + bytes = max 1 (quot (fromInteger (natVal' n) * bits) 8) + mba <- newAlignedPinnedByteArray bytes 16 + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# - integral :: IntegralType t -> Int - integral TypeInt = sizeOf (undefined::Int) - integral TypeInt8 = 1 - integral TypeInt16 = 2 - integral TypeInt32 = 4 - integral TypeInt64 = 8 - integral TypeWord = sizeOf (undefined::Word) - integral TypeWord8 = 1 - integral TypeWord16 = 2 - integral TypeWord32 = 4 - integral TypeWord64 = 8 - - floating :: FloatingType t -> Int - floating TypeHalf = 2 - floating TypeFloat = 4 - floating TypeDouble = 8 showElt :: TypeR e -> e -> String showElt t v = showsElt t v "" @@ -120,38 +118,62 @@ showsElt = tuple tuple (TupRsingle tp) val = scalar tp val scalar :: ScalarType e -> e -> ShowS - scalar (SingleScalarType t) e = single t e - scalar (VectorScalarType t) e = vector t e + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t - single :: SingleType e -> e -> ShowS - single (NumSingleType t) = num t + bit :: BitType e -> e -> ShowS + bit TypeBit = shows + bit TypeMask{} = shows . BitMask num :: NumType e -> e -> ShowS num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t integral :: IntegralType e -> e -> ShowS - integral TypeInt = shows - integral TypeInt8 = shows - integral TypeInt16 = shows - integral TypeInt32 = shows - integral TypeInt64 = shows - integral TypeWord = shows - integral TypeWord8 = shows - integral TypeWord16 = shows - integral TypeWord32 = shows - integral TypeWord64 = shows + integral = \case + SingleIntegralType t -> single t + VectorIntegralType _ t -> vector t + where + single :: SingleIntegralType t -> t -> ShowS + single TypeInt8 = shows + single TypeInt16 = shows + single TypeInt32 = shows + single TypeInt64 = shows + single TypeInt128 = shows + single TypeWord8 = shows + single TypeWord16 = shows + single TypeWord32 = shows + single TypeWord64 = shows + single TypeWord128 = shows + + vector :: KnownNat n => SingleIntegralType t -> Vec n t -> ShowS + vector TypeInt8 = shows + vector TypeInt16 = shows + vector TypeInt32 = shows + vector TypeInt64 = shows + vector TypeInt128 = shows + vector TypeWord8 = shows + vector TypeWord16 = shows + vector TypeWord32 = shows + vector TypeWord64 = shows + vector TypeWord128 = shows floating :: FloatingType e -> e -> ShowS - floating TypeHalf = shows - floating TypeFloat = shows - floating TypeDouble = shows - - vector :: VectorType (Vec n a) -> Vec n a -> ShowS - vector (VectorType _ s) vec - | SingleDict <- singleDict s - = showString - $ "<" ++ intercalate ", " ((\v -> single s v "") <$> listOfVec vec) ++ ">" + floating = \case + SingleFloatingType t -> single t + VectorFloatingType _ t -> vector t + where + single :: SingleFloatingType t -> t -> ShowS + single TypeFloat16 = shows + single TypeFloat32 = shows + single TypeFloat64 = shows + single TypeFloat128 = shows + + vector :: KnownNat n => SingleFloatingType t -> Vec n t -> ShowS + vector TypeFloat16 = shows + vector TypeFloat32 = shows + vector TypeFloat64 = shows + vector TypeFloat128 = shows liftElt :: TypeR t -> t -> CodeQ t liftElt TupRunit () = [|| () ||] diff --git a/src/Data/Array/Accelerate/Representation/Shape.hs b/src/Data/Array/Accelerate/Representation/Shape.hs index fa3651c03..fb68d71ab 100644 --- a/src/Data/Array/Accelerate/Representation/Shape.hs +++ b/src/Data/Array/Accelerate/Representation/Shape.hs @@ -23,14 +23,12 @@ import Data.Array.Accelerate.Representation.Type import Language.Haskell.TH.Extra import Prelude hiding ( zip ) -import GHC.Base ( quotInt, remInt ) - -- | Shape and index representations as nested pairs -- data ShapeR sh where ShapeRz :: ShapeR () - ShapeRsnoc :: ShapeR sh -> ShapeR (sh, Int) + ShapeRsnoc :: ShapeR sh -> ShapeR (sh, INT) -- | Nicely format a shape as a string -- @@ -40,9 +38,9 @@ showShape shr = foldr (\sh str -> str ++ " :. " ++ show sh) "Z" . shapeToList sh -- Synonyms for common shape types -- type DIM0 = () -type DIM1 = ((), Int) -type DIM2 = (((), Int), Int) -type DIM3 = ((((), Int), Int), Int) +type DIM1 = ((), INT) +type DIM2 = (((), INT), INT) +type DIM3 = ((((), INT), INT), INT) dim0 :: ShapeR DIM0 dim0 = ShapeRz @@ -58,13 +56,13 @@ dim3 = ShapeRsnoc dim2 -- | Number of dimensions of a /shape/ or /index/ (>= 0) -- -rank :: ShapeR sh -> Int +rank :: ShapeR sh -> INT rank ShapeRz = 0 rank (ShapeRsnoc shr) = rank shr + 1 -- | Total number of elements in an array of the given shape -- -size :: ShapeR sh -> sh -> Int +size :: ShapeR sh -> sh -> INT size ShapeRz () = 1 size (ShapeRsnoc shr) (sh, sz) | sz <= 0 = 0 @@ -86,7 +84,7 @@ intersect = zip min union :: ShapeR sh -> sh -> sh -> sh union = zip max -zip :: (Int -> Int -> Int) -> ShapeR sh -> sh -> sh -> sh +zip :: (INT -> INT -> INT) -> ShapeR sh -> sh -> sh -> sh zip _ ShapeRz () () = () zip f (ShapeRsnoc shr) (as, a) (bs, b) = (zip f shr as bs, f a b) @@ -99,25 +97,25 @@ eq (ShapeRsnoc shr) (sh, i) (sh', i') = i == i' && eq shr sh sh' -- representation of the array (first argument is the /shape/, second -- argument is the /index/). -- -toIndex :: HasCallStack => ShapeR sh -> sh -> sh -> Int -toIndex ShapeRz () () = 0 +toIndex :: HasCallStack => ShapeR sh -> sh -> sh -> INT +toIndex ShapeRz () () = 0 toIndex (ShapeRsnoc shr) (sh, sz) (ix, i) = indexCheck i sz $ toIndex shr sh ix * sz + i -- | Inverse of 'toIndex' -- -fromIndex :: HasCallStack => ShapeR sh -> sh -> Int -> sh -fromIndex ShapeRz () _ = () +fromIndex :: HasCallStack => ShapeR sh -> sh -> INT -> sh +fromIndex ShapeRz () _ = () fromIndex (ShapeRsnoc shr) (sh, sz) i - = (fromIndex shr sh (i `quotInt` sz), r) + = (fromIndex shr sh (i `quot` sz), r) -- If we assume that the index is in range, there is no point in computing -- the remainder for the highest dimension since i < sz must hold. -- where r = case shr of -- Check if rank of shr is 0 ShapeRz -> indexCheck i sz i - _ -> i `remInt` sz + _ -> i `rem` sz -- | Iterate through the entire shape, applying the function in the second -- argument; third argument combines results and fourth is an initial value @@ -157,13 +155,13 @@ shapeToRange (ShapeRsnoc shr) (sh, sz) = let (low, high) = shapeToRange shr sh i -- | Convert a shape or index into its list of dimensions -- -shapeToList :: ShapeR sh -> sh -> [Int] +shapeToList :: ShapeR sh -> sh -> [INT] shapeToList ShapeRz () = [] shapeToList (ShapeRsnoc shr) (sh,sz) = sz : shapeToList shr sh -- | Convert a list of dimensions into a shape -- -listToShape :: HasCallStack => ShapeR sh -> [Int] -> sh +listToShape :: HasCallStack => ShapeR sh -> [INT] -> sh listToShape shr ds = case listToShape' shr ds of Just sh -> sh @@ -171,17 +169,14 @@ listToShape shr ds = -- | Attempt to convert a list of dimensions into a shape -- -listToShape' :: ShapeR sh -> [Int] -> Maybe sh +listToShape' :: ShapeR sh -> [INT] -> Maybe sh listToShape' ShapeRz [] = Just () listToShape' (ShapeRsnoc shr) (x:xs) = (, x) <$> listToShape' shr xs listToShape' _ _ = Nothing shapeType :: ShapeR sh -> TypeR sh shapeType ShapeRz = TupRunit -shapeType (ShapeRsnoc shr) = - shapeType shr - `TupRpair` - TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeInt))) +shapeType (ShapeRsnoc shR) = shapeType shR `TupRpair` TupRsingle scalarType rnfShape :: ShapeR sh -> sh -> () rnfShape ShapeRz () = () diff --git a/src/Data/Array/Accelerate/Representation/Slice.hs b/src/Data/Array/Accelerate/Representation/Slice.hs index dee059a37..8e0246ad1 100644 --- a/src/Data/Array/Accelerate/Representation/Slice.hs +++ b/src/Data/Array/Accelerate/Representation/Slice.hs @@ -19,6 +19,7 @@ module Data.Array.Accelerate.Representation.Slice where import Data.Array.Accelerate.Representation.Shape +import Data.Array.Accelerate.Type import Language.Haskell.TH.Extra @@ -38,15 +39,15 @@ instance Slice () where sliceIndex = SliceNil instance Slice sl => Slice (sl, ()) where - type SliceShape (sl, ()) = (SliceShape sl, Int) + type SliceShape (sl, ()) = (SliceShape sl, INT) type CoSliceShape (sl, ()) = CoSliceShape sl - type FullShape (sl, ()) = (FullShape sl, Int) + type FullShape (sl, ()) = (FullShape sl, INT) sliceIndex = SliceAll (sliceIndex @sl) -instance Slice sl => Slice (sl, Int) where - type SliceShape (sl, Int) = SliceShape sl - type CoSliceShape (sl, Int) = (CoSliceShape sl, Int) - type FullShape (sl, Int) = (FullShape sl, Int) +instance Slice sl => Slice (sl, INT) where + type SliceShape (sl, INT) = SliceShape sl + type CoSliceShape (sl, INT) = (CoSliceShape sl, INT) + type FullShape (sl, INT) = (FullShape sl, INT) sliceIndex = SliceFixed (sliceIndex @sl) -- |Generalised array index, which may index only in a subset of the dimensions @@ -54,8 +55,8 @@ instance Slice sl => Slice (sl, Int) where -- data SliceIndex ix slice coSlice sliceDim where SliceNil :: SliceIndex () () () () - SliceAll :: SliceIndex ix slice co dim -> SliceIndex (ix, ()) (slice, Int) co (dim, Int) - SliceFixed :: SliceIndex ix slice co dim -> SliceIndex (ix, Int) slice (co, Int) (dim, Int) + SliceAll :: SliceIndex ix slice co dim -> SliceIndex (ix, ()) (slice, INT) co (dim, INT) + SliceFixed :: SliceIndex ix slice co dim -> SliceIndex (ix, INT) slice (co, INT) (dim, INT) instance Show (SliceIndex ix slice coSlice sliceDim) where show SliceNil = "SliceNil" diff --git a/src/Data/Array/Accelerate/Representation/Stencil.hs b/src/Data/Array/Accelerate/Representation/Stencil.hs index dd546721c..804892800 100644 --- a/src/Data/Array/Accelerate/Representation/Stencil.hs +++ b/src/Data/Array/Accelerate/Representation/Stencil.hs @@ -25,6 +25,7 @@ module Data.Array.Accelerate.Representation.Stencil ( import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Shape import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Type import Language.Haskell.TH.Extra @@ -40,14 +41,14 @@ data StencilR sh e pat where StencilRtup3 :: StencilR sh e pat1 -> StencilR sh e pat2 -> StencilR sh e pat3 - -> StencilR (sh, Int) e (Tup3 pat1 pat2 pat3) + -> StencilR (sh, INT) e (Tup3 pat1 pat2 pat3) StencilRtup5 :: StencilR sh e pat1 -> StencilR sh e pat2 -> StencilR sh e pat3 -> StencilR sh e pat4 -> StencilR sh e pat5 - -> StencilR (sh, Int) e (Tup5 pat1 pat2 pat3 pat4 pat5) + -> StencilR (sh, INT) e (Tup5 pat1 pat2 pat3 pat4 pat5) StencilRtup7 :: StencilR sh e pat1 -> StencilR sh e pat2 @@ -56,7 +57,7 @@ data StencilR sh e pat where -> StencilR sh e pat5 -> StencilR sh e pat6 -> StencilR sh e pat7 - -> StencilR (sh, Int) e (Tup7 pat1 pat2 pat3 pat4 pat5 pat6 pat7) + -> StencilR (sh, INT) e (Tup7 pat1 pat2 pat3 pat4 pat5 pat6 pat7) StencilRtup9 :: StencilR sh e pat1 -> StencilR sh e pat2 @@ -67,7 +68,7 @@ data StencilR sh e pat where -> StencilR sh e pat7 -> StencilR sh e pat8 -> StencilR sh e pat9 - -> StencilR (sh, Int) e (Tup9 pat1 pat2 pat3 pat4 pat5 pat6 pat7 pat8 pat9) + -> StencilR (sh, INT) e (Tup9 pat1 pat2 pat3 pat4 pat5 pat6 pat7 pat8 pat9) stencilEltR :: StencilR sh e pat -> TypeR e stencilEltR (StencilRunit3 t) = t @@ -123,7 +124,7 @@ stencilHalo = go' go :: StencilR sh e stencil -> sh go = snd . go' - cons :: ShapeR sh -> Int -> sh -> (sh, Int) + cons :: ShapeR sh -> INT -> sh -> (sh, INT) cons ShapeRz ix () = ((), ix) cons (ShapeRsnoc shr) ix (sh, sz) = (cons shr ix sh, sz) diff --git a/src/Data/Array/Accelerate/Representation/Tag.hs b/src/Data/Array/Accelerate/Representation/Tag.hs index ed7e07e80..eace299f6 100644 --- a/src/Data/Array/Accelerate/Representation/Tag.hs +++ b/src/Data/Array/Accelerate/Representation/Tag.hs @@ -1,5 +1,6 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Representation.Tag @@ -24,45 +25,74 @@ import Language.Haskell.TH.Extra -- type TAG = Word8 +data TagType t where + TagBit :: TagType Bit + TagWord8 :: TagType Word8 + TagWord16 :: TagType Word16 + -- | This structure both witnesses the layout of our representation types -- (as TupR does) and represents a complete path of pattern matching -- through this type. It indicates which fields of the structure represent --- the union tags (TagRtag) or store undefined values (TagRundef). +-- the union tags (TagRcon or TagRenum) or store undefined values (TagRundef). -- --- The function 'eltTags' produces all valid paths through the type. For +-- The function 'tagsR' produces all valid paths through the type. For -- example the type '(Bool,Bool)' produces the following: -- --- ghci> putStrLn . unlines . map show $ eltTags @(Bool,Bool) --- (((),(0#,())),(0#,())) -- (False, False) --- (((),(0#,())),(1#,())) -- (False, True) --- (((),(1#,())),(0#,())) -- (True, False) --- (((),(1#,())),(1#,())) -- (True, True) +-- ghci> putStrLn . unlines . map show $ tagsR @(Bool,Bool) +-- (((),0#),0#) -- (False, False) +-- (((),0#),1#) -- (False, True) +-- (((),1#),0#) -- (True, False) +-- (((),1#),1#) -- (True, True) -- data TagR a where TagRunit :: TagR () TagRsingle :: ScalarType a -> TagR a TagRundef :: ScalarType a -> TagR a - TagRtag :: TAG -> TagR a -> TagR (TAG, a) TagRpair :: TagR a -> TagR b -> TagR (a, b) + TagRcon :: TagType t -> t -> TagR a -> TagR (t, a) -- data constructors + TagRenum :: TagType t -> t -> TagR t -- enumerations instance Show (TagR a) where show TagRunit = "()" show TagRsingle{} = "." show TagRundef{} = "undef" - show (TagRtag v t) = "(" ++ show v ++ "#," ++ show t ++ ")" - show (TagRpair ta tb) = "(" ++ show ta ++ "," ++ show tb ++ ")" + show (TagRpair a b) = "(" ++ show a ++ "," ++ show b ++ ")" + show (TagRcon tR t e) = "(" ++ showTag tR t ++ "#," ++ show e ++ ")" + show (TagRenum tR t) = showTag tR t ++ "#" + +showTag :: TagType t -> t -> String +showTag TagBit = show +showTag TagWord8 = show +showTag TagWord16 = show rnfTag :: TagR a -> () -rnfTag TagRunit = () -rnfTag (TagRsingle t) = rnfScalarType t -rnfTag (TagRundef t) = rnfScalarType t -rnfTag (TagRtag v t) = v `seq` rnfTag t -rnfTag (TagRpair ta tb) = rnfTag ta `seq` rnfTag tb +rnfTag TagRunit = () +rnfTag (TagRsingle e) = rnfScalarType e +rnfTag (TagRundef e) = rnfScalarType e +rnfTag (TagRpair a b) = rnfTag a `seq` rnfTag b +rnfTag (TagRenum tR t) = rnfTagType tR `seq` t `seq` () +rnfTag (TagRcon tR t e) = rnfTagType tR `seq` t `seq` rnfTag e + +rnfTagType :: TagType t -> () +rnfTagType TagBit = () +rnfTagType TagWord8 = () +rnfTagType TagWord16 = () + +liftTagR :: TagR a -> CodeQ (TagR a) +liftTagR TagRunit = [|| TagRunit ||] +liftTagR (TagRsingle e) = [|| TagRsingle $$(liftScalarType e) ||] +liftTagR (TagRundef e) = [|| TagRundef $$(liftScalarType e) ||] +liftTagR (TagRpair a b) = [|| TagRpair $$(liftTagR a) $$(liftTagR b) ||] +liftTagR (TagRcon tR t e) = [|| TagRcon $$(liftTagType tR) $$(liftTag tR t) $$(liftTagR e) ||] +liftTagR (TagRenum tR t) = [|| TagRenum $$(liftTagType tR) $$(liftTag tR t) ||] + +liftTag :: TagType t -> t -> CodeQ t +liftTag TagBit (Bit x) = [|| Bit x ||] +liftTag TagWord8 x = [|| x ||] +liftTag TagWord16 x = [|| x ||] -liftTag :: TagR a -> CodeQ (TagR a) -liftTag TagRunit = [|| TagRunit ||] -liftTag (TagRsingle t) = [|| TagRsingle $$(liftScalarType t) ||] -liftTag (TagRundef t) = [|| TagRundef $$(liftScalarType t) ||] -liftTag (TagRtag v t) = [|| TagRtag v $$(liftTag t) ||] -liftTag (TagRpair ta tb) = [|| TagRpair $$(liftTag ta) $$(liftTag tb) ||] +liftTagType :: TagType t -> CodeQ (TagType t) +liftTagType TagBit = [|| TagBit ||] +liftTagType TagWord8 = [|| TagWord8 ||] +liftTagType TagWord16 = [|| TagWord16 ||] diff --git a/src/Data/Array/Accelerate/Representation/Type.hs b/src/Data/Array/Accelerate/Representation/Type.hs index 9b154133f..b275b58a0 100644 --- a/src/Data/Array/Accelerate/Representation/Type.hs +++ b/src/Data/Array/Accelerate/Representation/Type.hs @@ -26,6 +26,8 @@ import Data.Primitive.Vec import Formatting import Language.Haskell.TH.Extra +import GHC.TypeLits + -- | Both arrays (Acc) and expressions (Exp) are represented as nested -- pairs consisting of: @@ -45,7 +47,11 @@ data TupR s a where TupRsingle :: s a -> TupR s a TupRpair :: TupR s a -> TupR s b -> TupR s (a, b) -deriving instance (forall a. Show (s a)) => Show (TupR s t) +instance (forall a. Show (s a)) => Show (TupR s t) where + show = \case + TupRunit -> "()" + TupRsingle t -> show t + TupRpair a b -> "(" ++ show a ++ "," ++ show b ++ ")" formatTypeR :: Format r (TypeR a -> r) formatTypeR = later $ \case @@ -82,35 +88,45 @@ liftTypeQ = tuple tuple (TupRsingle t) = scalar t scalar :: ScalarType t -> TypeQ - scalar (SingleScalarType t) = single t - scalar (VectorScalarType t) = vector t - - vector :: VectorType (Vec n a) -> TypeQ - vector (VectorType n t) = [t| Vec $(litT (numTyLit (toInteger n))) $(single t) |] + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t - single :: SingleType t -> TypeQ - single (NumSingleType t) = num t + bit :: BitType t -> TypeQ + bit TypeBit = [t| Bit |] + bit (TypeMask n) = [t| Vec $(litT (numTyLit (natVal' n))) Bit |] num :: NumType t -> TypeQ num (IntegralNumType t) = integral t num (FloatingNumType t) = floating t integral :: IntegralType t -> TypeQ - integral TypeInt = [t| Int |] - integral TypeInt8 = [t| Int8 |] - integral TypeInt16 = [t| Int16 |] - integral TypeInt32 = [t| Int32 |] - integral TypeInt64 = [t| Int64 |] - integral TypeWord = [t| Word |] - integral TypeWord8 = [t| Word8 |] - integral TypeWord16 = [t| Word16 |] - integral TypeWord32 = [t| Word32 |] - integral TypeWord64 = [t| Word64 |] + integral = \case + SingleIntegralType t -> [t| $(single t) |] + VectorIntegralType n t -> [t| Vec $(litT (numTyLit (natVal' n))) $(single t) |] + where + single :: SingleIntegralType t -> TypeQ + single TypeInt8 = [t| Int8 |] + single TypeInt16 = [t| Int16 |] + single TypeInt32 = [t| Int32 |] + single TypeInt64 = [t| Int64 |] + single TypeInt128 = [t| Int128 |] + single TypeWord8 = [t| Word8 |] + single TypeWord16 = [t| Word16 |] + single TypeWord32 = [t| Word32 |] + single TypeWord64 = [t| Word64 |] + single TypeWord128 = [t| Word128 |] floating :: FloatingType t -> TypeQ - floating TypeHalf = [t| Half |] - floating TypeFloat = [t| Float |] - floating TypeDouble = [t| Double |] + floating = \case + SingleFloatingType t -> [t| $(single t) |] + VectorFloatingType n t -> [t| Vec $(litT (numTyLit (natVal' n))) $(single t) |] + where + single :: SingleFloatingType t -> TypeQ + single TypeFloat16 = [t| Half |] + single TypeFloat32 = [t| Float |] + single TypeFloat64 = [t| Double |] + single TypeFloat128 = [t| Float128 |] + runQ $ let diff --git a/src/Data/Array/Accelerate/Representation/Vec.hs b/src/Data/Array/Accelerate/Representation/Vec.hs index 35eac3b6c..378a640f6 100644 --- a/src/Data/Array/Accelerate/Representation/Vec.hs +++ b/src/Data/Array/Accelerate/Representation/Vec.hs @@ -1,10 +1,5 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE MagicHash #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE OverloadedStrings #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Representation.Vec @@ -19,77 +14,99 @@ module Data.Array.Accelerate.Representation.Vec where -import Data.Array.Accelerate.Type +import Data.Array.Accelerate.Error import Data.Array.Accelerate.Representation.Type -import Data.Primitive.Vec +import Data.Array.Accelerate.Type -import Control.Monad.ST -import Data.Primitive.ByteArray -import Data.Primitive.Types -import Language.Haskell.TH.Extra +import Data.Primitive.Bit as Prim -import GHC.Base ( Int(..), Int#, (-#) ) -import GHC.TypeNats +import qualified GHC.Exts as GHC --- | Declares the size of a SIMD vector and the type of its elements. This --- data type is used to denote the relation between a vector type (Vec --- n single) with its tuple representation (tuple). Conversions between --- those types are exposed through 'pack' and 'unpack'. --- -data VecR (n :: Nat) single tuple where - VecRnil :: SingleType s -> VecR 0 s () - VecRsucc :: VecR n s t -> VecR (n + 1) s (t, s) - -vecRvector :: KnownNat n => VecR n s tuple -> VectorType (Vec n s) -vecRvector = uncurry VectorType . go +toList :: TypeR v -> TypeR a -> v -> [a] +toList = go where - go :: VecR n s tuple -> (Int, SingleType s) - go (VecRnil tp) = (0, tp) - go (VecRsucc vec) | (n, tp) <- go vec = (n + 1, tp) + go :: TypeR v -> TypeR t -> v -> [t] + go TupRunit TupRunit _ = repeat () + go (TupRpair va vb) (TupRpair ta tb) (a, b) = zip (go va ta a) (go vb tb b) + go (TupRsingle v) (TupRsingle t) xs = scalar v t xs + go _ _ _ = internalError "unexpected vector encoding" -vecRtuple :: VecR n s tuple -> TypeR tuple -vecRtuple = snd . go - where - go :: VecR n s tuple -> (SingleType s, TypeR tuple) - go (VecRnil tp) = (tp, TupRunit) - go (VecRsucc vec) | (tp, tuple) <- go vec = (tp, TupRpair tuple (TupRsingle (SingleScalarType tp))) - -pack :: forall n single tuple. KnownNat n => VecR n single tuple -> tuple -> Vec n single -pack vecR tuple - | VectorType n single <- vecRvector vecR - , SingleDict <- singleDict single - = runST $ do - mba <- newByteArray (n * sizeOf (undefined :: single)) - go (n - 1) vecR tuple mba - ByteArray ba# <- unsafeFreezeByteArray mba - return $! Vec ba# - where - go :: Prim single => Int -> VecR n' single tuple' -> tuple' -> MutableByteArray s -> ST s () - go _ (VecRnil _) () _ = return () - go i (VecRsucc r) (xs, x) mba = do - writeByteArray mba i x - go (i - 1) r xs mba - -unpack :: forall n single tuple. KnownNat n => VecR n single tuple -> Vec n single -> tuple -unpack vecR (Vec ba#) - | VectorType n single <- vecRvector vecR - , (I# n#) <- n - , SingleDict <- singleDict single - = go (n# -# 1#) vecR + scalar :: ScalarType v -> ScalarType t -> v -> [t] + scalar (NumScalarType v) (NumScalarType t) = num v t + scalar (BitScalarType v) (BitScalarType t) = bit v t + scalar _ _ = internalError "unexpected vector encoding" + + bit :: BitType v -> BitType t -> v -> [t] + bit (TypeMask _) TypeBit = GHC.toList . Prim.BitMask + bit _ _ = internalError "unexpected vector encoding" + + num :: NumType v -> NumType t -> v -> [t] + num (IntegralNumType v) (IntegralNumType t) = integral v t + num (FloatingNumType v) (FloatingNumType t) = floating v t + num _ _ = internalError "unexpected vector encoding" + + integral :: IntegralType v -> IntegralType t -> v -> [t] + integral (VectorIntegralType _ TypeInt8) (SingleIntegralType TypeInt8) = GHC.toList + integral (VectorIntegralType _ TypeInt16) (SingleIntegralType TypeInt16) = GHC.toList + integral (VectorIntegralType _ TypeInt32) (SingleIntegralType TypeInt32) = GHC.toList + integral (VectorIntegralType _ TypeInt64) (SingleIntegralType TypeInt64) = GHC.toList + integral (VectorIntegralType _ TypeInt128) (SingleIntegralType TypeInt128) = GHC.toList + integral (VectorIntegralType _ TypeWord8) (SingleIntegralType TypeWord8) = GHC.toList + integral (VectorIntegralType _ TypeWord16) (SingleIntegralType TypeWord16) = GHC.toList + integral (VectorIntegralType _ TypeWord32) (SingleIntegralType TypeWord32) = GHC.toList + integral (VectorIntegralType _ TypeWord64) (SingleIntegralType TypeWord64) = GHC.toList + integral (VectorIntegralType _ TypeWord128) (SingleIntegralType TypeWord128) = GHC.toList + integral _ _ = internalError "unexpected vector encoding" + + floating :: FloatingType v -> FloatingType t -> v -> [t] + floating (VectorFloatingType _ TypeFloat16) (SingleFloatingType TypeFloat16) = GHC.toList + floating (VectorFloatingType _ TypeFloat32) (SingleFloatingType TypeFloat32) = GHC.toList + floating (VectorFloatingType _ TypeFloat64) (SingleFloatingType TypeFloat64) = GHC.toList + floating (VectorFloatingType _ TypeFloat128) (SingleFloatingType TypeFloat128) = GHC.toList + floating _ _ = internalError "unexpected vector encoding" + + +fromList :: TypeR v -> TypeR a -> [a] -> v +fromList = go where - go :: Prim single => Int# -> VecR n' single tuple' -> tuple' - go _ (VecRnil _) = () - go i# (VecRsucc r) = x `seq` xs `seq` (xs, x) - where - xs = go (i# -# 1#) r - x = indexByteArray# ba# i# - -rnfVecR :: VecR n single tuple -> () -rnfVecR (VecRnil tp) = rnfSingleType tp -rnfVecR (VecRsucc vec) = rnfVecR vec - -liftVecR :: VecR n single tuple -> CodeQ (VecR n single tuple) -liftVecR (VecRnil tp) = [|| VecRnil $$(liftSingleType tp) ||] -liftVecR (VecRsucc vec) = [|| VecRsucc $$(liftVecR vec) ||] + go :: TypeR v -> TypeR t -> [t] -> v + go TupRunit TupRunit _ = () + go (TupRpair va vb) (TupRpair ta tb) xs = let (as, bs) = unzip xs in (go va ta as, go vb tb bs) + go (TupRsingle v) (TupRsingle t) xs = scalar v t xs + go _ _ _ = error "unexpected vector encoding" + + scalar :: ScalarType v -> ScalarType t -> [t] -> v + scalar (NumScalarType v) (NumScalarType t) = num v t + scalar (BitScalarType v) (BitScalarType t) = bit v t + scalar _ _ = internalError "unexpected vector encoding" + + bit :: BitType v -> BitType t -> [t] -> v + bit (TypeMask _) TypeBit = Prim.unMask . GHC.fromList + bit _ _ = internalError "unexpected vector encoding" + + num :: NumType v -> NumType t -> [t] -> v + num (IntegralNumType v) (IntegralNumType t) = integral v t + num (FloatingNumType v) (FloatingNumType t) = floating v t + num _ _ = internalError "unexpected vector encoding" + + integral :: IntegralType v -> IntegralType t -> [t] -> v + integral (VectorIntegralType _ TypeInt8) (SingleIntegralType TypeInt8) = GHC.fromList + integral (VectorIntegralType _ TypeInt16) (SingleIntegralType TypeInt16) = GHC.fromList + integral (VectorIntegralType _ TypeInt32) (SingleIntegralType TypeInt32) = GHC.fromList + integral (VectorIntegralType _ TypeInt64) (SingleIntegralType TypeInt64) = GHC.fromList + integral (VectorIntegralType _ TypeInt128) (SingleIntegralType TypeInt128) = GHC.fromList + integral (VectorIntegralType _ TypeWord8) (SingleIntegralType TypeWord8) = GHC.fromList + integral (VectorIntegralType _ TypeWord16) (SingleIntegralType TypeWord16) = GHC.fromList + integral (VectorIntegralType _ TypeWord32) (SingleIntegralType TypeWord32) = GHC.fromList + integral (VectorIntegralType _ TypeWord64) (SingleIntegralType TypeWord64) = GHC.fromList + integral (VectorIntegralType _ TypeWord128) (SingleIntegralType TypeWord128) = GHC.fromList + integral _ _ = internalError "unexpected vector encoding" + + floating :: FloatingType v -> FloatingType t -> [t] -> v + floating (VectorFloatingType _ TypeFloat16) (SingleFloatingType TypeFloat16) = GHC.fromList + floating (VectorFloatingType _ TypeFloat32) (SingleFloatingType TypeFloat32) = GHC.fromList + floating (VectorFloatingType _ TypeFloat64) (SingleFloatingType TypeFloat64) = GHC.fromList + floating (VectorFloatingType _ TypeFloat128) (SingleFloatingType TypeFloat128) = GHC.fromList + floating _ _ = internalError "unexpected vector encoding" diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 8fa577f41..c033c4471 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -1,10 +1,11 @@ {-# LANGUAGE AllowAmbiguousTypes #-} -{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -45,45 +46,38 @@ module Data.Array.Accelerate.Smart ( HasArraysR(..), HasTypeR(..), - -- ** Smart constructors for literals + -- ** Constants constant, undef, - -- ** Smart destructors for shapes + -- ** Shapes indexHead, indexTail, - -- ** Smart constructors for constants - mkMinBound, mkMaxBound, mkPi, - mkSin, mkCos, mkTan, - mkAsin, mkAcos, mkAtan, - mkSinh, mkCosh, mkTanh, - mkAsinh, mkAcosh, mkAtanh, - mkExpFloating, mkSqrt, mkLog, - mkFPow, mkLogBase, - mkTruncate, mkRound, mkFloor, mkCeiling, - mkAtan2, - - -- ** Smart constructors for primitive functions - mkAdd, mkSub, mkMul, mkNeg, mkAbs, mkSig, mkQuot, mkRem, mkQuotRem, mkIDiv, mkMod, mkDivMod, - mkBAnd, mkBOr, mkBXor, mkBNot, mkBShiftL, mkBShiftR, mkBRotateL, mkBRotateR, mkPopCount, mkCountLeadingZeros, mkCountTrailingZeros, - mkFDiv, mkRecip, mkLt, mkGt, mkLtEq, mkGtEq, mkEq, mkNEq, mkMax, mkMin, - mkLAnd, mkLOr, mkLNot, mkIsNaN, mkIsInfinite, - - -- ** Smart constructors for type coercion functions - mkFromIntegral, mkToFloating, mkBitcast, mkCoerce, Coerce(..), - - -- ** Auxiliary functions - ($$), ($$$), ($$$$), ($$$$$), - ApplyAcc(..), - unAcc, unAccFunction, mkExp, unExp, unExpFunction, unExpBinaryFunction, unPair, mkPairToTuple, + -- ** SIMD vectors + splat, pack, unpack, + extract, mkExtract, + insert, mkInsert, + shuffle, + select, + + -- ** Type coercions + mkBitcast, + mkCoerce, Coerce(..), + mkAcoerce, Acoerce(..), -- ** Miscellaneous - formatPreAccOp, - formatPreExpOp, + ($$), ($$$), ($$$$), ($$$$$), + ApplyAcc(..), + unAcc, unAccFunction, + mkExp, unExp, unExpFunction, unExpBinaryFunction, mkPrimUnary, mkPrimBinary, + unPair, mkPairToTuple, + formatPreAccOp, formatPreExpOp, ) where +import Data.Array.Accelerate.AST ( Direction(..), Message(..), RescaleFactor, PrimBool, PrimMaybe, PrimFun(..), primFunType ) import Data.Array.Accelerate.AST.Idx +import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Error import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Elt @@ -92,27 +86,25 @@ import Data.Array.Accelerate.Representation.Slice import Data.Array.Accelerate.Representation.Stencil hiding ( StencilR, stencilR ) import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type -import Data.Array.Accelerate.Representation.Vec import Data.Array.Accelerate.Sugar.Array ( Arrays ) import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Sugar.Foreign import Data.Array.Accelerate.Sugar.Shape ( (:.)(..) ) +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Type import qualified Data.Array.Accelerate.Representation.Stencil as R import qualified Data.Array.Accelerate.Sugar.Array as Sugar import qualified Data.Array.Accelerate.Sugar.Shape as Sugar -import Data.Array.Accelerate.AST ( Direction(..), Message(..) - , PrimBool, PrimMaybe - , PrimFun(..), primFunType - , PrimConst(..), primConstType ) -import Data.Primitive.Vec +import qualified Data.Primitive.Vec as Prim import Data.Kind import Data.Text.Lazy.Builder -import Formatting +import Formatting hiding ( splat ) +import GHC.Prim import GHC.TypeLits +import GHC.TypeLits.Extra -- Array computations @@ -361,6 +353,11 @@ data PreSmartAcc acc exp as where -> acc arrs2 -> PreSmartAcc acc exp arrs2 + Acoerce :: RescaleFactor + -> TypeR b + -> acc (Array (sh, INT) a) + -> PreSmartAcc acc exp (Array (sh, INT) b) + Use :: ArrayR (Array sh e) -> Array sh e -> PreSmartAcc acc exp (Array sh e) @@ -406,30 +403,30 @@ data PreSmartAcc acc exp as where Fold :: TypeR e -> (SmartExp e -> SmartExp e -> exp e) -> Maybe (exp e) - -> acc (Array (sh, Int) e) + -> acc (Array (sh, INT) e) -> PreSmartAcc acc exp (Array sh e) - FoldSeg :: IntegralType i + FoldSeg :: SingleIntegralType i -> TypeR e -> (SmartExp e -> SmartExp e -> exp e) -> Maybe (exp e) - -> acc (Array (sh, Int) e) + -> acc (Array (sh, INT) e) -> acc (Segments i) - -> PreSmartAcc acc exp (Array (sh, Int) e) + -> PreSmartAcc acc exp (Array (sh, INT) e) Scan :: Direction -> TypeR e -> (SmartExp e -> SmartExp e -> exp e) -> Maybe (exp e) - -> acc (Array (sh, Int) e) - -> PreSmartAcc acc exp (Array (sh, Int) e) + -> acc (Array (sh, INT) e) + -> PreSmartAcc acc exp (Array (sh, INT) e) Scan' :: Direction -> TypeR e -> (SmartExp e -> SmartExp e -> exp e) -> exp e - -> acc (Array (sh, Int) e) - -> PreSmartAcc acc exp (Array (sh, Int) e, Array sh e) + -> acc (Array (sh, INT) e) + -> PreSmartAcc acc exp (Array (sh, INT) e, Array sh e) Permute :: ArrayR (Array sh e) -> (SmartExp e -> SmartExp e -> exp e) @@ -510,24 +507,40 @@ data PreSmartExp acc exp t where -> exp (t1, t2) -> PreSmartExp acc exp t - VecPack :: KnownNat n - => VecR n s tup - -> exp tup - -> PreSmartExp acc exp (Vec n s) + Extract :: ScalarType (Prim.Vec n a) + -> SingleIntegralType i + -> exp (Prim.Vec n a) + -> exp i + -> PreSmartExp acc exp a - VecUnpack :: KnownNat n - => VecR n s tup - -> exp (Vec n s) - -> PreSmartExp acc exp tup + Insert :: ScalarType (Prim.Vec n a) + -> SingleIntegralType i + -> exp (Prim.Vec n a) + -> exp i + -> exp a + -> PreSmartExp acc exp (Prim.Vec n a) + + Shuffle :: ScalarType (Prim.Vec m a) + -> SingleIntegralType i + -> exp (Prim.Vec n a) + -> exp (Prim.Vec n a) + -> exp (Prim.Vec m i) + -> PreSmartExp acc exp (Prim.Vec m a) + + Select :: ScalarType (Prim.Vec n a) + -> exp (Prim.Vec n PrimBool) + -> exp (Prim.Vec n a) + -> exp (Prim.Vec n a) + -> PreSmartExp acc exp (Prim.Vec n a) ToIndex :: ShapeR sh -> exp sh -> exp sh - -> PreSmartExp acc exp Int + -> PreSmartExp acc exp INT FromIndex :: ShapeR sh -> exp sh - -> exp Int + -> exp INT -> PreSmartExp acc exp sh Case :: exp a @@ -545,9 +558,6 @@ data PreSmartExp acc exp t where -> exp t -> PreSmartExp acc exp t - PrimConst :: PrimConst t - -> PreSmartExp acc exp t - PrimApp :: PrimFun (a -> r) -> exp a -> PreSmartExp acc exp r @@ -559,7 +569,7 @@ data PreSmartExp acc exp t where LinearIndex :: TypeR t -> acc (Array sh t) - -> exp Int + -> exp INT -> PreSmartExp acc exp t Shape :: ShapeR sh @@ -568,7 +578,7 @@ data PreSmartExp acc exp t where ShapeSize :: ShapeR sh -> exp sh - -> PreSmartExp acc exp Int + -> PreSmartExp acc exp INT Foreign :: Foreign asm => TypeR y @@ -580,8 +590,7 @@ data PreSmartExp acc exp t where Undef :: ScalarType t -> PreSmartExp acc exp t - Coerce :: BitSizeEq a b - => ScalarType a + Bitcast :: ScalarType a -> ScalarType b -> exp a -> PreSmartExp acc exp b @@ -807,30 +816,32 @@ instance HasArraysR acc => HasArraysR (PreSmartAcc acc exp) where PairIdxRight -> t2 Aprj _ _ -> error "Ejector seat? You're joking!" Atrace _ _ a -> arraysR a + Acoerce _ bR a -> let ArrayR shR _ = arrayR a + in TupRsingle $ ArrayR shR bR Use repr _ -> TupRsingle repr - Unit tp _ -> TupRsingle $ ArrayR ShapeRz $ tp + Unit aR _ -> TupRsingle $ ArrayR ShapeRz aR Generate repr _ _ -> TupRsingle repr - Reshape shr _ a -> let ArrayR _ tp = arrayR a - in TupRsingle $ ArrayR shr tp - Replicate si _ a -> let ArrayR _ tp = arrayR a - in TupRsingle $ ArrayR (sliceDomainR si) tp - Slice si a _ -> let ArrayR _ tp = arrayR a - in TupRsingle $ ArrayR (sliceShapeR si) tp - Map _ tp _ a -> let ArrayR shr _ = arrayR a - in TupRsingle $ ArrayR shr tp - ZipWith _ _ tp _ a _ -> let ArrayR shr _ = arrayR a - in TupRsingle $ ArrayR shr tp - Fold _ _ _ a -> let ArrayR (ShapeRsnoc shr) tp = arrayR a - in TupRsingle (ArrayR shr tp) + Reshape shr _ a -> let ArrayR _ aR = arrayR a + in TupRsingle $ ArrayR shr aR + Replicate si _ a -> let ArrayR _ aR = arrayR a + in TupRsingle $ ArrayR (sliceDomainR si) aR + Slice si a _ -> let ArrayR _ aR = arrayR a + in TupRsingle $ ArrayR (sliceShapeR si) aR + Map _ aR _ a -> let ArrayR shr _ = arrayR a + in TupRsingle $ ArrayR shr aR + ZipWith _ _ aR _ a _ -> let ArrayR shr _ = arrayR a + in TupRsingle $ ArrayR shr aR + Fold _ _ _ a -> let ArrayR (ShapeRsnoc shr) aR = arrayR a + in TupRsingle (ArrayR shr aR) FoldSeg _ _ _ _ a _ -> arraysR a Scan _ _ _ _ a -> arraysR a - Scan' _ _ _ _ a -> let repr@(ArrayR (ShapeRsnoc shr) tp) = arrayR a - in TupRsingle repr `TupRpair` TupRsingle (ArrayR shr tp) + Scan' _ _ _ _ a -> let repr@(ArrayR (ShapeRsnoc shr) aR) = arrayR a + in TupRsingle repr `TupRpair` TupRsingle (ArrayR shr aR) Permute _ _ a _ _ -> arraysR a - Backpermute shr _ _ a -> let ArrayR _ tp = arrayR a - in TupRsingle (ArrayR shr tp) - Stencil s tp _ _ _ -> TupRsingle $ ArrayR (stencilShapeR s) tp - Stencil2 s _ tp _ _ _ _ _ -> TupRsingle $ ArrayR (stencilShapeR s) tp + Backpermute shr _ _ a -> let ArrayR _ aR = arrayR a + in TupRsingle (ArrayR shr aR) + Stencil s aR _ _ _ -> TupRsingle $ ArrayR (stencilShapeR s) aR + Stencil2 s _ aR _ _ _ _ _ -> TupRsingle $ ArrayR (stencilShapeR s) aR class HasTypeR f where @@ -841,9 +852,9 @@ instance HasTypeR SmartExp where instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where typeR = \case - Tag tp _ -> tp + Tag tR _ -> tR Match _ e -> typeR e - Const tp _ -> TupRsingle tp + Const tR _ -> TupRsingle tR Nil -> TupRunit Pair e1 e2 -> typeR e1 `TupRpair` typeR e2 Prj idx e @@ -851,23 +862,44 @@ instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where PairIdxLeft -> t1 PairIdxRight -> t2 Prj _ _ -> error "I never joke about my work" - VecPack vecR _ -> TupRsingle $ VectorScalarType $ vecRvector vecR - VecUnpack vecR _ -> vecRtuple vecR - ToIndex _ _ _ -> TupRsingle scalarTypeInt + Extract vR _ _ _ -> TupRsingle (scalar vR) + where + scalar :: ScalarType (Prim.Vec n a) -> ScalarType a + scalar (NumScalarType t) = NumScalarType (num t) + scalar (BitScalarType t) = BitScalarType (bit t) + + bit :: BitType (Prim.Vec n a) -> BitType a + bit TypeMask{} = TypeBit + + num :: NumType (Prim.Vec n a) -> NumType a + num (IntegralNumType t) = IntegralNumType (integral t) + num (FloatingNumType t) = FloatingNumType (floating t) + + integral :: IntegralType (Prim.Vec n a) -> IntegralType a + integral (SingleIntegralType t) = case t of + integral (VectorIntegralType _ t) = SingleIntegralType t + + floating :: FloatingType (Prim.Vec n a) -> FloatingType a + floating (SingleFloatingType t) = case t of + floating (VectorFloatingType _ t) = SingleFloatingType t + -- + Insert t _ _ _ _ -> TupRsingle t + Shuffle t _ _ _ _ -> TupRsingle t + Select t _ _ _ -> TupRsingle t + ToIndex _ _ _ -> TupRsingle (scalarType @INT) FromIndex shr _ _ -> shapeType shr Case _ ((_,c):_) -> typeR c Case{} -> internalError "encountered empty case" Cond _ e _ -> typeR e While t _ _ _ -> t - PrimConst c -> TupRsingle $ SingleScalarType $ primConstType c PrimApp f _ -> snd $ primFunType f - Index tp _ _ -> tp - LinearIndex tp _ _ -> tp + Index tR _ _ -> tR + LinearIndex tR _ _ -> tR Shape shr _ -> shapeType shr - ShapeSize _ _ -> TupRsingle scalarTypeInt - Foreign tp _ _ _ -> tp - Undef tp -> TupRsingle tp - Coerce _ tp _ -> TupRsingle tp + ShapeSize _ _ -> TupRsingle (scalarType @INT) + Foreign tR _ _ _ -> tR + Undef tR -> TupRsingle tR + Bitcast _ tR _ -> TupRsingle tR -- Smart constructors @@ -890,9 +922,10 @@ constant = Exp . go (eltR @e) . fromElt where go :: HasCallStack => TypeR t -> t -> SmartExp t go TupRunit () = SmartExp $ Nil - go (TupRsingle tp) c = SmartExp $ Const tp c + go (TupRsingle tR) c = SmartExp $ Const tR c go (TupRpair t1 t2) (c1, c2) = SmartExp $ go t1 c1 `Pair` go t2 c2 + -- | 'undef' can be used anywhere a constant is expected, and indicates that the -- consumer of the value can receive an unspecified bit pattern. -- @@ -924,6 +957,7 @@ undef = Exp $ go $ eltR @e go (TupRsingle t) = SmartExp $ Undef t go (TupRpair t1 t2) = SmartExp $ go t1 `Pair` go t2 + -- | Get the innermost dimension of a shape. -- -- The innermost dimension (right-most component of the shape) is the index of @@ -943,249 +977,247 @@ indexTail :: (Elt sh, Elt a) => Exp (sh :. a) -> Exp sh indexTail (Exp x) = mkExp $ Prj PairIdxLeft x --- Smart constructor for constants +-- | Fill all lanes of a SIMD vector with the given value -- - -mkMinBound :: (Elt t, IsBounded (EltR t)) => Exp t -mkMinBound = mkExp $ PrimConst (PrimMinBound boundedType) - -mkMaxBound :: (Elt t, IsBounded (EltR t)) => Exp t -mkMaxBound = mkExp $ PrimConst (PrimMaxBound boundedType) - -mkPi :: (Elt r, IsFloating (EltR r)) => Exp r -mkPi = mkExp $ PrimConst (PrimPi floatingType) - - --- Smart constructors for primitive applications +-- @since 1.4.0.0 -- +splat :: (SIMD n a, Elt a) => Exp a -> Exp (Vec n a) +splat x = pack (repeat x) --- Operators from Floating - -mkSin :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -mkSin = mkPrimUnary $ PrimSin floatingType - -mkCos :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -mkCos = mkPrimUnary $ PrimCos floatingType - -mkTan :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -mkTan = mkPrimUnary $ PrimTan floatingType - -mkAsin :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -mkAsin = mkPrimUnary $ PrimAsin floatingType - -mkAcos :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -mkAcos = mkPrimUnary $ PrimAcos floatingType - -mkAtan :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -mkAtan = mkPrimUnary $ PrimAtan floatingType - -mkSinh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -mkSinh = mkPrimUnary $ PrimSinh floatingType - -mkCosh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -mkCosh = mkPrimUnary $ PrimCosh floatingType - -mkTanh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -mkTanh = mkPrimUnary $ PrimTanh floatingType - -mkAsinh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -mkAsinh = mkPrimUnary $ PrimAsinh floatingType - -mkAcosh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -mkAcosh = mkPrimUnary $ PrimAcosh floatingType - -mkAtanh :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -mkAtanh = mkPrimUnary $ PrimAtanh floatingType - -mkExpFloating :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -mkExpFloating = mkPrimUnary $ PrimExpFloating floatingType - -mkSqrt :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -mkSqrt = mkPrimUnary $ PrimSqrt floatingType - -mkLog :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -mkLog = mkPrimUnary $ PrimLog floatingType - -mkFPow :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -> Exp t -mkFPow = mkPrimBinary $ PrimFPow floatingType - -mkLogBase :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -> Exp t -mkLogBase = mkPrimBinary $ PrimLogBase floatingType - --- Operators from Num - -mkAdd :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t -> Exp t -mkAdd = mkPrimBinary $ PrimAdd numType - -mkSub :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t -> Exp t -mkSub = mkPrimBinary $ PrimSub numType - -mkMul :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t -> Exp t -mkMul = mkPrimBinary $ PrimMul numType - -mkNeg :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t -mkNeg = mkPrimUnary $ PrimNeg numType - -mkAbs :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t -mkAbs = mkPrimUnary $ PrimAbs numType - -mkSig :: (Elt t, IsNum (EltR t)) => Exp t -> Exp t -mkSig = mkPrimUnary $ PrimSig numType - --- Operators from Integral - -mkQuot :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t -mkQuot = mkPrimBinary $ PrimQuot integralType - -mkRem :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t -mkRem = mkPrimBinary $ PrimRem integralType - -mkQuotRem :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> (Exp t, Exp t) -mkQuotRem (Exp x) (Exp y) = - let pair = SmartExp $ PrimQuotRem integralType `PrimApp` SmartExp (Pair x y) - in (mkExp $ Prj PairIdxLeft pair, mkExp $ Prj PairIdxRight pair) - -mkIDiv :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t -mkIDiv = mkPrimBinary $ PrimIDiv integralType - -mkMod :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t -mkMod = mkPrimBinary $ PrimMod integralType - -mkDivMod :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> (Exp t, Exp t) -mkDivMod (Exp x) (Exp y) = - let pair = SmartExp $ PrimDivMod integralType `PrimApp` SmartExp (Pair x y) - in (mkExp $ Prj PairIdxLeft pair, mkExp $ Prj PairIdxRight pair) - --- Operators from Bits and FiniteBits - -mkBAnd :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t -mkBAnd = mkPrimBinary $ PrimBAnd integralType - -mkBOr :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t -mkBOr = mkPrimBinary $ PrimBOr integralType - -mkBXor :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -> Exp t -mkBXor = mkPrimBinary $ PrimBXor integralType - -mkBNot :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp t -mkBNot = mkPrimUnary $ PrimBNot integralType - -mkBShiftL :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t -mkBShiftL = mkPrimBinary $ PrimBShiftL integralType - -mkBShiftR :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t -mkBShiftR = mkPrimBinary $ PrimBShiftR integralType - -mkBRotateL :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t -mkBRotateL = mkPrimBinary $ PrimBRotateL integralType - -mkBRotateR :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -> Exp t -mkBRotateR = mkPrimBinary $ PrimBRotateR integralType - -mkPopCount :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -mkPopCount = mkPrimUnary $ PrimPopCount integralType - -mkCountLeadingZeros :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -mkCountLeadingZeros = mkPrimUnary $ PrimCountLeadingZeros integralType - -mkCountTrailingZeros :: (Elt t, IsIntegral (EltR t)) => Exp t -> Exp Int -mkCountTrailingZeros = mkPrimUnary $ PrimCountTrailingZeros integralType - - --- Operators from Fractional - -mkFDiv :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -> Exp t -mkFDiv = mkPrimBinary $ PrimFDiv floatingType - -mkRecip :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -mkRecip = mkPrimUnary $ PrimRecip floatingType - --- Operators from RealFrac - -mkTruncate :: (Elt a, Elt b, IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b -mkTruncate = mkPrimUnary $ PrimTruncate floatingType integralType - -mkRound :: (Elt a, Elt b, IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b -mkRound = mkPrimUnary $ PrimRound floatingType integralType - -mkFloor :: (Elt a, Elt b, IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b -mkFloor = mkPrimUnary $ PrimFloor floatingType integralType - -mkCeiling :: (Elt a, Elt b, IsFloating (EltR a), IsIntegral (EltR b)) => Exp a -> Exp b -mkCeiling = mkPrimUnary $ PrimCeiling floatingType integralType - --- Operators from RealFloat - -mkAtan2 :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp t -> Exp t -mkAtan2 = mkPrimBinary $ PrimAtan2 floatingType - -mkIsNaN :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp Bool -mkIsNaN = mkPrimUnaryBool $ PrimIsNaN floatingType - -mkIsInfinite :: (Elt t, IsFloating (EltR t)) => Exp t -> Exp Bool -mkIsInfinite = mkPrimUnaryBool $ PrimIsInfinite floatingType - --- FIXME: add missing operations from Floating, RealFrac & RealFloat - --- Relational and equality operators - -mkLt :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool -mkLt = mkPrimBinaryBool $ PrimLt singleType - -mkGt :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool -mkGt = mkPrimBinaryBool $ PrimGt singleType - -mkLtEq :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool -mkLtEq = mkPrimBinaryBool $ PrimLtEq singleType - -mkGtEq :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool -mkGtEq = mkPrimBinaryBool $ PrimGtEq singleType - -mkEq :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool -mkEq = mkPrimBinaryBool $ PrimEq singleType - -mkNEq :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp Bool -mkNEq = mkPrimBinaryBool $ PrimNEq singleType - -mkMax :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp t -mkMax = mkPrimBinary $ PrimMax singleType - -mkMin :: (Elt t, IsSingle (EltR t)) => Exp t -> Exp t -> Exp t -mkMin = mkPrimBinary $ PrimMin singleType - --- Logical operators +-- | Pack scalar expressions into a single SIMD vector +-- +-- @since 1.4.0.0 +-- +pack :: forall n a. (SIMD n a, Elt a) => [Exp a] -> Exp (Vec n a) +pack xs = + let go :: Word8 -> [Exp a] -> Exp (Vec n a) -> Exp (Vec n a) + go _ [] vec = vec + go i (v:vs) vec = go (i+1) vs (insert vec (constant i) v) + -- + n = fromIntegral (natVal' (proxy# :: Proxy# n)) + in + go 0 (take n xs) undef + +-- | Unpack the lanes of a SIMD vector into scalar elements +-- +-- @since 1.4.0.0 +-- +unpack :: forall n a. (SIMD n a, Elt a) => Exp (Vec n a) -> [Exp a] +unpack v = + let n = fromIntegral (natVal' (proxy# :: Proxy# n)) :: Word8 + in map (extract v . constant) [0 .. n-1] -mkLAnd :: Exp Bool -> Exp Bool -> Exp Bool -mkLAnd (Exp a) (Exp b) = mkExp $ SmartExp (PrimApp PrimLAnd (SmartExp $ Pair x y)) `Pair` SmartExp Nil +-- | Extract a single scalar element from the given SIMD vector at the +-- specified index +-- +-- @since 1.4.0.0 +-- +extract :: forall n a i. (Elt a, SIMD n a, IsSingleIntegral (EltR i)) + => Exp (Vec n a) + -> Exp i + -> Exp a +extract (Exp v) (Exp i) = Exp $ mkExtract (vecR @n @a) (eltR @a) singleIntegralType v i + +mkExtract :: TypeR v -> TypeR e -> SingleIntegralType i -> SmartExp v -> SmartExp i -> SmartExp e +mkExtract _vR _eR iR _v i = go _vR _eR _v where - x = SmartExp $ Prj PairIdxLeft a - y = SmartExp $ Prj PairIdxLeft b - -mkLOr :: Exp Bool -> Exp Bool -> Exp Bool -mkLOr (Exp a) (Exp b) = mkExp $ SmartExp (PrimApp PrimLOr (SmartExp $ Pair x y)) `Pair` SmartExp Nil + go :: TypeR v -> TypeR e -> SmartExp v -> SmartExp e + go TupRunit TupRunit _ = SmartExp Nil + go (TupRsingle vR) (TupRsingle eR) v = scalar vR eR v + go (TupRpair vR1 vR2) (TupRpair eR1 eR2) v = + let (v1, v2) = unPair v + in SmartExp $ go vR1 eR1 v1 `Pair` go vR2 eR2 v2 + go _ _ _ = error "impossible" + + scalar :: ScalarType v -> ScalarType e -> SmartExp v -> SmartExp e + scalar (NumScalarType vR) (NumScalarType eR) = num vR eR + scalar (BitScalarType vR) (BitScalarType eR) = bit vR eR + scalar _ _ = error "impossible" + + bit :: BitType v -> BitType e -> SmartExp v -> SmartExp e + bit TypeMask{} TypeBit v = SmartExp $ Extract scalarType iR v i + bit _ _ _ = error "impossible" + + num :: NumType v -> NumType e -> SmartExp v -> SmartExp e + num (IntegralNumType vR) (IntegralNumType eR) = integral vR eR + num (FloatingNumType vR) (FloatingNumType eR) = floating vR eR + num _ _ = error "impossible" + + integral :: IntegralType v -> IntegralType e -> SmartExp v -> SmartExp e + integral (VectorIntegralType n vR) (SingleIntegralType tR) v + | Just Refl <- matchSingleIntegralType vR tR + = SmartExp $ Extract (NumScalarType (IntegralNumType (VectorIntegralType n vR))) iR v i + integral _ _ _ = error "impossible" + + floating :: FloatingType v -> FloatingType e -> SmartExp v -> SmartExp e + floating (VectorFloatingType n vR) (SingleFloatingType tR) v + | Just Refl <- matchSingleFloatingType vR tR + = SmartExp $ Extract (NumScalarType (FloatingNumType (VectorFloatingType n vR))) iR v i + floating _ _ _ = error "impossible" + + +-- | Insert a scalar element into the given SIMD vector at the specified +-- index +-- +-- @since 1.4.0.0 +-- +insert :: forall n a i. (Elt a, SIMD n a, IsSingleIntegral (EltR i)) + => Exp (Vec n a) + -> Exp i + -> Exp a + -> Exp (Vec n a) +insert (Exp v) (Exp i) (Exp x) = Exp $ mkInsert (vecR @n @a) (eltR @a) singleIntegralType v i x + +mkInsert :: TypeR v -> TypeR e -> SingleIntegralType i -> SmartExp v -> SmartExp i -> SmartExp e -> SmartExp v +mkInsert _vR _eR iR _v i _x = go _vR _eR _v _x where - x = SmartExp $ Prj PairIdxLeft a - y = SmartExp $ Prj PairIdxLeft b - -mkLNot :: Exp Bool -> Exp Bool -mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil + go :: TypeR v -> TypeR e -> SmartExp v -> SmartExp e -> SmartExp v + go TupRunit TupRunit _ _ = SmartExp Nil + go (TupRsingle vR) (TupRsingle eR) v e = scalar vR eR v e + go (TupRpair vR1 vR2) (TupRpair eR1 eR2) v e = + let (v1, v2) = unPair v + (e1, e2) = unPair e + in SmartExp $ go vR1 eR1 v1 e1 `Pair` go vR2 eR2 v2 e2 + go _ _ _ _ = error "impossible" + + scalar :: ScalarType v -> ScalarType e -> SmartExp v -> SmartExp e -> SmartExp v + scalar (NumScalarType vR) (NumScalarType eR) = num vR eR + scalar (BitScalarType vR) (BitScalarType eR) = bit vR eR + scalar _ _ = error "impossible" + + bit :: BitType v -> BitType e -> SmartExp v -> SmartExp e -> SmartExp v + bit TypeMask{} TypeBit v = SmartExp . Insert scalarType iR v i + bit _ _ _ = error "impossible" + + num :: NumType v -> NumType e -> SmartExp v -> SmartExp e -> SmartExp v + num (IntegralNumType vR) (IntegralNumType eR) = integral vR eR + num (FloatingNumType vR) (FloatingNumType eR) = floating vR eR + num _ _ = error "impossible" + + integral :: IntegralType v -> IntegralType e -> SmartExp v -> SmartExp e -> SmartExp v + integral (VectorIntegralType n vR) (SingleIntegralType tR) v + | Just Refl <- matchSingleIntegralType vR tR + = SmartExp . Insert (NumScalarType (IntegralNumType (VectorIntegralType n vR))) iR v i + integral _ _ _ = error "impossible" + + floating :: FloatingType v -> FloatingType e -> SmartExp v -> SmartExp e -> SmartExp v + floating (VectorFloatingType n vR) (SingleFloatingType tR) v + | Just Refl <- matchSingleFloatingType vR tR + = SmartExp . Insert (NumScalarType (FloatingNumType (VectorFloatingType n vR))) iR v i + floating _ _ _ = error "impossible" + + +-- | Construct a permutation of elements from two input vectors +-- +-- The elements of the two input vectors are concatenated and numbered from +-- zero left-to-right. For each element in the result vector, the shuffle +-- mask selects the corresponding element from this concatenated vector to +-- copy to the result. +-- +-- @since 1.4.0.0 +-- +shuffle :: forall m n a i. (SIMD n a, SIMD m a, SIMD m i, IsSingleIntegral (EltR i)) + => Exp (Vec n a) + -> Exp (Vec n a) + -> Exp (Vec m i) + -> Exp (Vec m a) +shuffle (Exp xs) (Exp ys) (Exp i) = Exp $ go (vecR @n @a) (vecR @m @a) xs ys where - x = SmartExp $ Prj PairIdxLeft a - --- Numeric conversions - -mkFromIntegral :: (Elt a, Elt b, IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b -mkFromIntegral = mkPrimUnary $ PrimFromIntegral integralType numType - -mkToFloating :: (Elt a, Elt b, IsNum (EltR a), IsFloating (EltR b)) => Exp a -> Exp b -mkToFloating = mkPrimUnary $ PrimToFloating numType floatingType - --- Other conversions - --- NOTE: Restricted to scalar types with a type-level BitSizeEq constraint to --- make this version "safe" -mkBitcast :: forall b a. (Elt a, Elt b, IsScalar (EltR a), IsScalar (EltR b), BitSizeEq (EltR a) (EltR b)) => Exp a -> Exp b -mkBitcast (Exp a) = mkExp $ Coerce (scalarType @(EltR a)) (scalarType @(EltR b)) a + go :: TypeR t -> TypeR r -> SmartExp t -> SmartExp t -> SmartExp r + go TupRunit TupRunit _ _ = SmartExp Nil + go (TupRsingle vR) (TupRsingle rR) x y = scalar vR rR x y + go (TupRpair vR1 vR2) (TupRpair rR1 rR2) x y = + let (x1, x2) = unPair x + (y1, y2) = unPair y + in SmartExp $ go vR1 rR1 x1 y1 `Pair` go vR2 rR2 x2 y2 + go _ _ _ _ = error "impossible" + + scalar :: ScalarType t -> ScalarType r -> SmartExp t -> SmartExp t -> SmartExp r + scalar (NumScalarType vR) (NumScalarType rR) = num vR rR + scalar (BitScalarType vR) (BitScalarType rR) = bit vR rR + scalar _ _ = error "impossible" + + bit :: BitType t -> BitType r -> SmartExp t -> SmartExp t -> SmartExp r + bit (TypeMask _) (TypeMask m) x y + | TupRsingle (NumScalarType (IntegralNumType (VectorIntegralType m' iR))) <- vecR @m @i + , Just Refl <- sameNat' m m' + = SmartExp $ Shuffle (BitScalarType (TypeMask m)) iR x y i + bit _ _ _ _ = error "impossible" + + num :: NumType t -> NumType r -> SmartExp t -> SmartExp t -> SmartExp r + num (IntegralNumType vR) (IntegralNumType rR) = integral vR rR + num (FloatingNumType vR) (FloatingNumType rR) = floating vR rR + num _ _ = error "impossible" + + integral :: IntegralType t -> IntegralType r -> SmartExp t -> SmartExp t -> SmartExp r + integral (VectorIntegralType _ vR) (VectorIntegralType m rR) x y + | TupRsingle (NumScalarType (IntegralNumType (VectorIntegralType m' iR))) <- vecR @m @i + , Just Refl <- matchSingleIntegralType vR rR + , Just Refl <- sameNat' m m' + = SmartExp $ Shuffle (NumScalarType (IntegralNumType (VectorIntegralType m vR))) iR x y i + integral _ _ _ _ = error "impossible" + + floating :: FloatingType t -> FloatingType r -> SmartExp t -> SmartExp t -> SmartExp r + floating (VectorFloatingType _ vR) (VectorFloatingType m rR) x y + | TupRsingle (NumScalarType (IntegralNumType (VectorIntegralType m' iR))) <- vecR @m @i + , Just Refl <- matchSingleFloatingType vR rR + , Just Refl <- sameNat' m m' + = SmartExp $ Shuffle (NumScalarType (FloatingNumType (VectorFloatingType m vR))) iR x y i + floating _ _ _ _ = error "impossible" + + +-- | Choose one value based on a condition, without branching. This is strict in +-- both arguments. +-- +-- @since 1.4.0.0 +-- +select :: forall n a. SIMD n a + => Exp (Vec n Bool) + -> Exp (Vec n a) + -> Exp (Vec n a) + -> Exp (Vec n a) +select (Exp mask) (Exp tt) (Exp ff) = Exp $ go (vecR @n @a) tt ff + where + go :: TypeR t -> SmartExp t -> SmartExp t -> SmartExp t + go TupRunit _ _ = SmartExp Nil + go (TupRsingle vR) t f = scalar vR t f + go (TupRpair vR1 vR2) t f = + let (t1, t2) = unPair t + (f1, f2) = unPair f + in SmartExp $ go vR1 t1 f1 `Pair` go vR2 t2 f2 + + scalar :: ScalarType t -> SmartExp t -> SmartExp t -> SmartExp t + scalar (NumScalarType vR) = num vR + scalar (BitScalarType vR) = bit vR + + bit :: BitType t -> SmartExp t -> SmartExp t -> SmartExp t + bit (TypeMask n) + | Just Refl <- sameNat' n (proxy# :: Proxy# n) + = SmartExp $$ Select (BitScalarType (TypeMask n)) mask + bit _ = error "impossible" + + num :: NumType t -> SmartExp t -> SmartExp t -> SmartExp t + num (IntegralNumType vR) = integral vR + num (FloatingNumType vR) = floating vR + + integral :: IntegralType t -> SmartExp t -> SmartExp t -> SmartExp t + integral (VectorIntegralType n t) + | Just Refl <- sameNat' n (proxy# :: Proxy# n) + = SmartExp $$ Select (NumScalarType (IntegralNumType (VectorIntegralType n t))) mask + integral _ = error "impossible" + + floating :: FloatingType t -> SmartExp t -> SmartExp t -> SmartExp t + floating (VectorFloatingType n t) + | Just Refl <- sameNat' n (proxy# :: Proxy# n) + = SmartExp $$ Select (NumScalarType (FloatingNumType (VectorFloatingType n t))) mask + floating _ = error "impossible" + + +-- Coercions between types +-- ----------------------- + +mkBitcast :: forall b a. (IsScalar (EltR a), IsScalar (EltR b), BitSizeEq (EltR a) (EltR b)) => Exp a -> Exp b +mkBitcast (Exp a) = mkExp $ Bitcast (scalarType @(EltR a)) (scalarType @(EltR b)) a mkCoerce :: Coerce (EltR a) (EltR b) => Exp a -> Exp b mkCoerce (Exp a) = Exp $ mkCoerce' a @@ -1194,7 +1226,7 @@ class Coerce a b where mkCoerce' :: SmartExp a -> SmartExp b instance {-# OVERLAPS #-} (IsScalar a, IsScalar b, BitSizeEq a b) => Coerce a b where - mkCoerce' = SmartExp . Coerce (scalarType @a) (scalarType @b) + mkCoerce' = SmartExp . Bitcast (scalarType @a) (scalarType @b) instance (Coerce a1 b1, Coerce a2 b2) => Coerce (a1, a2) (b1, b2) where mkCoerce' a = SmartExp $ Pair (mkCoerce' $ SmartExp $ Prj PairIdxLeft a) (mkCoerce' $ SmartExp $ Prj PairIdxRight a) @@ -1215,6 +1247,89 @@ instance Coerce a (a, ()) where mkCoerce' a = SmartExp (Pair a (SmartExp Nil)) +mkAcoerce + :: forall b a sh. (Acoerce (EltR a) (EltR b)) + => Acc (Sugar.Array (sh :. Int) a) + -> Acc (Sugar.Array (sh :. Int) b) +mkAcoerce (Acc a) = + let ArrayR _ aR = arrayR a + (bR, sz) = mkAcoerce' @(EltR a) aR + in Acc $ SmartAcc (Acoerce sz bR a) + +class Acoerce a b where + mkAcoerce' :: TypeR a -> (TypeR b, RescaleFactor) + +instance {-# OVERLAPS #-} (IsScalar a, IsScalar b) => Acoerce a b where + mkAcoerce' _ = + let ta = scalarType @a + tb = scalarType @b + sa = scalar ta + sb = scalar tb + sz = case compare sa sb of + EQ -> 0 -- TLM: reuse this value for something else? rescale of ±1 achieves the same thing + GT -> sa `quot` sb + LT -> negate $ sb `quot` sa + -- + scalar :: ScalarType t -> RescaleFactor + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + bit :: BitType t -> RescaleFactor + bit TypeBit = 1 + bit (TypeMask n) = fromInteger (natVal' n) + + num :: NumType t -> RescaleFactor + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType t -> RescaleFactor + integral (VectorIntegralType n t) = fromInteger (natVal' n) * integral (SingleIntegralType t) + integral (SingleIntegralType t) = case t of + TypeInt n -> fromIntegral n + TypeWord n -> fromIntegral n + + floating :: FloatingType t -> RescaleFactor + floating (VectorFloatingType n t) = fromInteger (natVal' n) * floating (SingleFloatingType t) + floating (SingleFloatingType t) = case t of + TypeFloat16 -> 16 + TypeFloat32 -> 32 + TypeFloat64 -> 64 + TypeFloat128 -> 128 + in + (TupRsingle tb, sz) + +-- TODO: Make this a compile time error. This should be possible with type +-- families, but GHC seems to have problems normalising the BitSize of 'Vec's +-- (and the GHC.TypeLits.Normalise package unfortunately does not seem to help). +-- +instance (Acoerce a1 b1, Acoerce a2 b2) => Acoerce (a1, a2) (b1, b2) where + mkAcoerce' (TupRpair a1 a2) = + let (b1, s1) = mkAcoerce' a1 + (b2, s2) = mkAcoerce' a2 + in + if s1 == s2 + then (TupRpair b1 b2, s1) + else error $ formatToString ("Could not coerce type `" % formatTypeR % "' to `" % formatTypeR % "'") + (TupRpair a1 a2) (TupRpair b1 b2) + mkAcoerce' _ = error "impossible" + +instance Acoerce ((), a) a where + mkAcoerce' (TupRpair TupRunit a) = (a, 0) + mkAcoerce' _ = error "impossible" + +instance Acoerce (a, ()) a where + mkAcoerce' (TupRpair a TupRunit) = (a, 0) + mkAcoerce' _ = error "impossible" + +instance Acoerce a ((), a) where + mkAcoerce' a = (TupRpair TupRunit a, 0) + +instance Acoerce a (a, ()) where + mkAcoerce' a = (TupRpair a TupRunit, 0) + +instance Acoerce a a where + mkAcoerce' a = (a, 0) + -- Auxiliary functions -- -------------------- @@ -1247,24 +1362,18 @@ mkExp = Exp . SmartExp unExp :: Exp e -> SmartExp (EltR e) unExp (Exp e) = e -unExpFunction :: (Elt a, Elt b) => (Exp a -> Exp b) -> SmartExp (EltR a) -> SmartExp (EltR b) +unExpFunction :: (Exp a -> Exp b) -> SmartExp (EltR a) -> SmartExp (EltR b) unExpFunction f = unExp . f . Exp -unExpBinaryFunction :: (Elt a, Elt b, Elt c) => (Exp a -> Exp b -> Exp c) -> SmartExp (EltR a) -> SmartExp (EltR b) -> SmartExp (EltR c) +unExpBinaryFunction :: (Exp a -> Exp b -> Exp c) -> SmartExp (EltR a) -> SmartExp (EltR b) -> SmartExp (EltR c) unExpBinaryFunction f a b = unExp $ f (Exp a) (Exp b) -mkPrimUnary :: (Elt a, Elt b) => PrimFun (EltR a -> EltR b) -> Exp a -> Exp b +mkPrimUnary :: PrimFun (EltR a -> EltR b) -> Exp a -> Exp b mkPrimUnary prim (Exp a) = mkExp $ PrimApp prim a -mkPrimBinary :: (Elt a, Elt b, Elt c) => PrimFun ((EltR a, EltR b) -> EltR c) -> Exp a -> Exp b -> Exp c +mkPrimBinary :: PrimFun ((EltR a, EltR b) -> EltR c) -> Exp a -> Exp b -> Exp c mkPrimBinary prim (Exp a) (Exp b) = mkExp $ PrimApp prim (SmartExp $ Pair a b) -mkPrimUnaryBool :: Elt a => PrimFun (EltR a -> PrimBool) -> Exp a -> Exp Bool -mkPrimUnaryBool = mkCoerce @PrimBool $$ mkPrimUnary - -mkPrimBinaryBool :: (Elt a, Elt b) => PrimFun ((EltR a, EltR b) -> PrimBool) -> Exp a -> Exp b -> Exp Bool -mkPrimBinaryBool = mkCoerce @PrimBool $$$ mkPrimBinary - unPair :: SmartExp (a, b) -> (SmartExp a, SmartExp b) unPair e = (SmartExp $ Prj PairIdxLeft e, SmartExp $ Prj PairIdxRight e) @@ -1339,6 +1448,7 @@ formatPreAccOp = later $ \case Stencil{} -> "Stencil" Stencil2{} -> "Stencil2" Aforeign{} -> "Aforeign" + Acoerce{} -> "Acoerce" formatPreExpOp :: Format r (PreSmartExp acc exp t -> r) formatPreExpOp = later $ \case @@ -1349,19 +1459,20 @@ formatPreExpOp = later $ \case Nil{} -> "Nil" Pair{} -> "Pair" Prj{} -> "Prj" - VecPack{} -> "VecPack" - VecUnpack{} -> "VecUnpack" + Extract{} -> "Extract" + Insert{} -> "Insert" + Shuffle{} -> "Shuffle" + Select{} -> "Select" ToIndex{} -> "ToIndex" FromIndex{} -> "FromIndex" Case{} -> "Case" Cond{} -> "Cond" While{} -> "While" - PrimConst{} -> "PrimConst" PrimApp{} -> "PrimApp" Index{} -> "Index" LinearIndex{} -> "LinearIndex" Shape{} -> "Shape" ShapeSize{} -> "ShapeSize" Foreign{} -> "Foreign" - Coerce{} -> "Coerce" + Bitcast{} -> "Bitcast" diff --git a/src/Data/Array/Accelerate/Sugar/Array.hs b/src/Data/Array/Accelerate/Sugar/Array.hs index 87a490246..5d48c245a 100644 --- a/src/Data/Array/Accelerate/Sugar/Array.hs +++ b/src/Data/Array/Accelerate/Sugar/Array.hs @@ -38,6 +38,7 @@ import GHC.Generics import qualified GHC.Exts as GHC -- $setup +-- >>> import Prelude -- >>> :seti -XOverloadedLists @@ -153,7 +154,7 @@ infixl 9 ! -- infixl 9 !! (!!) :: forall sh e. Elt e => Array sh e -> Int -> e -(!!) (Array arr) i = toElt $ R.linearIndexArray (eltR @e) arr i +(!!) (Array arr) i = toElt $ R.linearIndexArray (eltR @e) arr (fromIntegral i) -- | Create an array from its representation function, applied at each -- index of the array @@ -174,8 +175,8 @@ fromFunctionM sh f = Array <$> R.fromFunctionM (arrayR @sh @e) (fromElt sh) f' -- | Create a vector from the concatenation of the given list of vectors -- -concatVectors :: forall e. Elt e => [Vector e] -> Vector e -concatVectors = toArr . R.concatVectors (eltR @e) . map fromArr +-- concatVectors :: forall e. Elt e => [Vector e] -> Vector e +-- concatVectors = toArr . R.concatVectors (eltR @e) . map fromArr -- | Creates a new, uninitialized Accelerate array -- diff --git a/src/Data/Array/Accelerate/Sugar/Elt.hs b/src/Data/Array/Accelerate/Sugar/Elt.hs index b55158900..a873bd221 100644 --- a/src/Data/Array/Accelerate/Sugar/Elt.hs +++ b/src/Data/Array/Accelerate/Sugar/Elt.hs @@ -32,6 +32,7 @@ import Data.Array.Accelerate.Type import Data.Bits import Data.Char import Data.Kind +import Foreign.C.Types import Language.Haskell.TH.Extra hiding ( Type ) import GHC.Generics @@ -45,8 +46,8 @@ import GHC.Generics -- tuples thereof, stored efficiently in memory as consecutive unpacked -- elements without pointers. It roughly consists of: -- --- * Signed and unsigned integers (8, 16, 32, and 64-bits wide) --- * Floating point numbers (half, single, and double precision) +-- * Signed and unsigned integers (8, 16, 32, 64, and 128-bits wide) +-- * Floating point numbers (IEEE half, single, double, and (optionally) quadruple precision) -- * 'Char' -- * 'Bool' -- * () @@ -163,7 +164,7 @@ instance (GElt a, GElt b) => GElt (a :*: b) where instance (GElt a, GElt b, GSumElt (a :+: b)) => GElt (a :+: b) where type GEltR t (a :+: b) = (TAG, GSumEltR t (a :+: b)) geltR t = TupRpair (TupRsingle scalarType) (gsumEltR @(a :+: b) t) - gtagsR t = uncurry TagRtag <$> gsumTagsR @(a :+: b) 0 t + gtagsR t = uncurry (TagRcon TagWord8) <$> gsumTagsR @(a :+: b) 0 t gfromElt = gsumFromElt 0 gtoElt (k,x) = gsumToElt k x gundef t = (0xff, gsumUndef @(a :+: b) t) @@ -282,11 +283,31 @@ untag (TupRpair ta tb) = TagRpair (untag ta) (untag tb) -- instance Elt () -instance Elt Bool instance Elt Ordering instance Elt a => Elt (Maybe a) instance (Elt a, Elt b) => Elt (Either a b) +instance Elt Bool where + type EltR Bool = Bit + eltR = TupRsingle scalarType + tagsR = [TagRenum TagBit 0, TagRenum TagBit 1] + toElt = unBit + fromElt = Bit + +instance Elt Int where + type EltR Int = INT + eltR = TupRsingle scalarType + tagsR = [TagRsingle scalarType] + toElt = fromIntegral + fromElt = fromIntegral + +instance Elt Word where + type EltR Word = WORD + eltR = TupRsingle scalarType + tagsR = [TagRsingle scalarType] + toElt = fromIntegral + fromElt = fromIntegral + instance Elt Char where type EltR Char = Word32 eltR = TupRsingle scalarType @@ -300,16 +321,16 @@ runQ $ do -- integralTypes :: [Name] integralTypes = - [ ''Int - , ''Int8 + [ ''Int8 , ''Int16 , ''Int32 , ''Int64 - , ''Word + , ''Int128 , ''Word8 , ''Word16 , ''Word32 , ''Word64 + , ''Word128 ] floatingTypes :: [Name] @@ -317,6 +338,7 @@ runQ $ do [ ''Half , ''Float , ''Double + , ''Float128 ] newtypes :: [Name] @@ -358,18 +380,6 @@ runQ $ do in instanceD ctx [t| Elt $res |] [] - -- mkVecElt :: Name -> Integer -> Q [Dec] - -- mkVecElt name n = - -- let t = conT name - -- v = [t| Vec $(litT (numTyLit n)) $t |] - -- in - -- [d| instance Elt $v where - -- type EltR $v = $v - -- eltR = TupRsingle scalarType - -- fromElt = id - -- toElt = id - -- |] - -- ghci> $( stringE . show =<< reify ''CFloat ) -- TyConI (NewtypeD [] Foreign.C.Types.CFloat [] Nothing (NormalC Foreign.C.Types.CFloat [(Bang NoSourceUnpackedness NoSourceStrictness,ConT GHC.Types.Float)]) []) -- @@ -391,6 +401,5 @@ runQ $ do ss <- mapM mkSimple (integralTypes ++ floatingTypes) ns <- mapM mkNewtype newtypes ts <- mapM mkTuple [2..16] - -- vs <- sequence [ mkVecElt t n | t <- integralTypes ++ floatingTypes, n <- [2,3,4,8,16] ] return (concat ss ++ concat ns ++ ts) diff --git a/src/Data/Array/Accelerate/Sugar/Shape.hs b/src/Data/Array/Accelerate/Sugar/Shape.hs index 1ac8bd0c4..a57366075 100644 --- a/src/Data/Array/Accelerate/Sugar/Shape.hs +++ b/src/Data/Array/Accelerate/Sugar/Shape.hs @@ -136,12 +136,12 @@ data Divide sh = Divide -- | Number of dimensions of a /shape/ or /index/ (>= 0) -- rank :: forall sh. Shape sh => Int -rank = R.rank (shapeR @sh) +rank = fromIntegral $ R.rank (shapeR @sh) -- | Total number of elements in an array of the given /shape/ -- size :: forall sh. Shape sh => sh -> Int -size = R.size (shapeR @sh) . fromElt +size = fromIntegral . R.size (shapeR @sh) . fromElt -- | The empty /shape/ -- @@ -165,7 +165,7 @@ toIndex :: forall sh. Shape sh => sh -- ^ Total shape (extent) of the array -> sh -- ^ The argument index -> Int -- ^ Corresponding linear index -toIndex sh ix = R.toIndex (shapeR @sh) (fromElt sh) (fromElt ix) +toIndex sh ix = fromIntegral $ R.toIndex (shapeR @sh) (fromElt sh) (fromElt ix) -- | Inverse of 'toIndex'. -- @@ -173,7 +173,7 @@ fromIndex :: forall sh. Shape sh => sh -- ^ Total shape (extent) of the array -> Int -- ^ The argument index -> sh -- ^ Corresponding multi-dimensional index -fromIndex sh = toElt . R.fromIndex (shapeR @sh) (fromElt sh) +fromIndex sh = toElt . R.fromIndex (shapeR @sh) (fromElt sh) . fromIntegral -- | Iterate through all of the indices of a shape, applying the given -- function at each index. The index space is traversed in row-major order. @@ -210,19 +210,19 @@ shapeToRange ix = -- | Convert a shape to a list of dimensions -- shapeToList :: forall sh. Shape sh => sh -> [Int] -shapeToList = R.shapeToList (shapeR @sh) . fromElt +shapeToList = map fromIntegral . R.shapeToList (shapeR @sh) . fromElt -- | Convert a list of dimensions into a shape. If the list does not -- contain exactly the number of elements as specified by the type of the -- shape: error. -- listToShape :: forall sh. Shape sh => [Int] -> sh -listToShape = toElt . R.listToShape (shapeR @sh) +listToShape = toElt . R.listToShape (shapeR @sh) . map fromIntegral -- | Attempt to convert a list of dimensions into a shape -- listToShape' :: forall sh. Shape sh => [Int] -> Maybe sh -listToShape' = fmap toElt . R.listToShape' (shapeR @sh) +listToShape' = fmap toElt . R.listToShape' (shapeR @sh) . map fromIntegral -- | Nicely format a shape as a string -- diff --git a/src/Data/Array/Accelerate/Sugar/Vec.hs b/src/Data/Array/Accelerate/Sugar/Vec.hs index bc49eeb8b..cee537e7c 100644 --- a/src/Data/Array/Accelerate/Sugar/Vec.hs +++ b/src/Data/Array/Accelerate/Sugar/Vec.hs @@ -1,10 +1,21 @@ -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE MagicHash #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_HADDOCK hide #-} -{-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.Sugar.Vec -- Copyright : [2008..2020] The Accelerate Team @@ -15,26 +26,375 @@ -- Portability : non-portable (GHC extensions) -- -module Data.Array.Accelerate.Sugar.Vec - where +module Data.Array.Accelerate.Sugar.Vec ( + + Vec(..), KnownNat, + V2, V3, V4, V8, V16, + SIMD(..), + +) where -import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Representation.Tag import Data.Array.Accelerate.Representation.Type +import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Shape import Data.Array.Accelerate.Type -import Data.Primitive.Types -import Data.Primitive.Vec +import qualified Data.Array.Accelerate.Representation.Vec as R +import qualified Data.Primitive.Vec as Prim + +import Data.Kind +import Prettyprinter +import Language.Haskell.TH.Extra hiding ( Type ) import GHC.TypeLits -import GHC.Prim +import GHC.Generics +import qualified GHC.Exts as GHC + + +-- | SIMD vectors of fixed width +-- +data Vec n a = Vec (VecR n a) + +-- Synonyms for common vector sizes +-- +type V2 = Vec 2 +type V3 = Vec 3 +type V4 = Vec 4 +type V8 = Vec 8 +type V16 = Vec 16 + +instance (Show a, SIMD n a) => Show (Vec n a) where + show = vec . toList + where + vec :: [a] -> String + vec = show + . group . encloseSep (flatAlt "< " "<") (flatAlt " >" ">") ", " + . map viaShow + +instance (Eq a, SIMD n a) => Eq (Vec n a) where + Vec x == Vec y = tuple (vecR @n @a) x y + where + tuple :: TypeR v -> v -> v -> Bool + tuple TupRunit () () = True + tuple (TupRpair aR bR) (a1,b1) (a2,b2) = tuple aR a1 a2 && tuple bR b1 b2 + tuple (TupRsingle t) a b = scalar t a b + + scalar :: ScalarType v -> v -> v -> Bool + scalar (NumScalarType t) = num t + scalar (BitScalarType t) = bit t + + bit :: BitType v -> v -> v -> Bool + bit TypeBit = (==) + bit TypeMask{} = (==) + + num :: NumType v -> v -> v -> Bool + num (IntegralNumType t) = integral t + num (FloatingNumType t) = floating t + + integral :: IntegralType v -> v -> v -> Bool + integral (SingleIntegralType t) = case t of + TypeInt8 -> (==) + TypeInt16 -> (==) + TypeInt32 -> (==) + TypeInt64 -> (==) + TypeInt128 -> (==) + TypeWord8 -> (==) + TypeWord16 -> (==) + TypeWord32 -> (==) + TypeWord64 -> (==) + TypeWord128 -> (==) + integral (VectorIntegralType _ t) = case t of + TypeInt8 -> (==) + TypeInt16 -> (==) + TypeInt32 -> (==) + TypeInt64 -> (==) + TypeInt128 -> (==) + TypeWord8 -> (==) + TypeWord16 -> (==) + TypeWord32 -> (==) + TypeWord64 -> (==) + TypeWord128 -> (==) + + floating :: FloatingType v -> v -> v -> Bool + floating (SingleFloatingType t) = case t of + TypeFloat16 -> (==) + TypeFloat32 -> (==) + TypeFloat64 -> (==) + TypeFloat128 -> (==) + floating (VectorFloatingType _ t) = case t of + TypeFloat16 -> (==) + TypeFloat32 -> (==) + TypeFloat64 -> (==) + TypeFloat128 -> (==) + + +instance SIMD n a => GHC.IsList (Vec n a) where + type Item (Vec n a) = a + toList = toList + fromList = fromList + +toList :: forall n a. SIMD n a => Vec n a -> [a] +toList (Vec vs) = map toElt $ R.toList (vecR @n @a) (eltR @a) vs + +fromList :: forall n a. SIMD n a => [a] -> Vec n a +fromList = Vec . R.fromList (vecR @n @a) (eltR @a) . map fromElt + +instance SIMD n a => Elt (Vec n a) where + type EltR (Vec n a) = VecR n a + eltR = vecR @n @a + tagsR = vtagsR @n @a + toElt = Vec + fromElt (Vec a) = a + +-- | The 'SIMD' class characterises the subset of scalar element types from +-- 'Elt' that can be packed into SIMD vectors. +-- +-- @since 1.4.0.0 +-- +class (KnownNat n, Elt a) => SIMD n a where + type VecR n a :: Type + type VecR n a = GVecR () n (Rep a) + + vecR :: TypeR (VecR n a) + vtagsR :: [TagR (VecR n a)] + + default vecR + :: (GVec n (Rep a), VecR n a ~ GVecR () n (Rep a)) + => TypeR (VecR n a) + vecR = gvecR @n @(Rep a) TupRunit + + -- default vtagsR + -- :: (GVec n (Rep a), VecR n a ~ GVecR () n (Rep a)) + -- => [TagR (VecR n a)] + -- vtagsR = gvtagsR @n @(Rep a) TagRunit + vtagsR = [tagOfType (vecR @n @a)] + +class KnownNat n => GVec n (f :: Type -> Type) where + type GVecR t n f + gvecR :: TypeR t -> TypeR (GVecR t n f) + gvtagsR :: TagR t -> [TagR (GVecR t n f)] + +instance KnownNat n => GVec n U1 where + type GVecR t n U1 = t + gvecR t = t + gvtagsR t = [t] + +instance GVec n a => GVec n (M1 i c a) where + type GVecR t n (M1 i c a) = GVecR t n a + gvecR = gvecR @n @a + gvtagsR = gvtagsR @n @a + +instance SIMD n a => GVec n (K1 i a) where + type GVecR t n (K1 i a) = (t, VecR n a) + gvecR t = TupRpair t (vecR @n @a) + gvtagsR t = TagRpair t <$> vtagsR @n @a + +instance (GVec n a, GVec n b) => GVec n (a :*: b) where + type GVecR t n (a :*: b) = GVecR (GVecR t n a) n b + gvecR = gvecR @n @b . gvecR @n @a + gvtagsR = concatMap (gvtagsR @n @b) . gvtagsR @n @a + +instance (GVec n a, GVec n b, GSumVec n (a :+: b)) => GVec n (a :+: b) where + type GVecR t n (a :+: b) = (Prim.Vec n TAG, GSumVecR t n (a :+: b)) + gvecR t = TupRpair (TupRsingle scalarType) (gsumvecR @n @(a :+: b) t) + -- gvtagsR t = let zero = Prim.fromList (replicate (fromInteger (natVal' (proxy# @n))) 0) + -- in uncurry TagRtag <$> gsumvtagsR @n @(a :+: b) zero t + gvtagsR _ = error "TODO: gvtagsR (:+:)" + +class KnownNat n => GSumVec n (f :: Type -> Type) where + type GSumVecR t n f + gsumvecR :: TypeR t -> TypeR (GSumVecR t n f) + -- gsumvtagsR :: Prim.Vec n TAG -> TagR t -> [(Prim.Vec n TAG, TagR (GSumVecR t n f))] + +instance KnownNat n => GSumVec n U1 where + type GSumVecR t n U1 = t + gsumvecR t = t + +instance GSumVec n a => GSumVec n (M1 i c a) where + type GSumVecR t n (M1 i c a) = GSumVecR t n a + gsumvecR = gsumvecR @n @a + +instance (KnownNat n, SIMD n a) => GSumVec n (K1 i a) where + type GSumVecR t n (K1 i a) = (t, VecR n a) + gsumvecR t = TupRpair t (vecR @n @a) + +instance (GVec n a, GVec n b) => GSumVec n (a :*: b) where + type GSumVecR t n (a :*: b) = GVecR t n (a :*: b) + gsumvecR = gvecR @n @(a :*: b) + +instance (GSumVec n a, GSumVec n b) => GSumVec n (a :+: b) where + type GSumVecR t n (a :+: b) = GSumVecR (GSumVecR t n a) n b + gsumvecR = gsumvecR @n @b . gsumvecR @n @a + +tagOfType :: TypeR a -> TagR a +tagOfType TupRunit = TagRunit +tagOfType (TupRpair s t) = TagRpair (tagOfType s) (tagOfType t) +tagOfType (TupRsingle t) = TagRsingle t + +instance KnownNat n => SIMD n Z +instance KnownNat n => SIMD n () +instance KnownNat n => SIMD n Ordering +instance SIMD n a => SIMD n (Maybe a) +instance SIMD n sh => SIMD n (sh :. Int) +instance (SIMD n a, SIMD n b) => SIMD n (Either a b) + +instance KnownNat n => SIMD n Bool where + type VecR n Bool = Prim.Vec n Bit + vecR = TupRsingle scalarType + vtagsR = [TagRsingle scalarType] + +instance KnownNat n => SIMD n Int where + type VecR n Int = Prim.Vec n (EltR Int) + vecR = TupRsingle scalarType + vtagsR = [TagRsingle scalarType] + +instance KnownNat n => SIMD n Word where + type VecR n Word = Prim.Vec n (EltR Word) + vecR = TupRsingle scalarType + vtagsR = [TagRsingle scalarType] + +instance KnownNat n => SIMD n Char where + type VecR n Char = Prim.Vec n (EltR Char) + vecR = TupRsingle scalarType + vtagsR = [TagRsingle scalarType] + +runQ $ do + let + integralTypes :: [Name] + integralTypes = + [ ''Int8 + , ''Int16 + , ''Int32 + , ''Int64 + , ''Int128 + , ''Word8 + , ''Word16 + , ''Word32 + , ''Word64 + , ''Word128 + ] + + floatingTypes :: [Name] + floatingTypes = + [ ''Half + , ''Float + , ''Double + , ''Float128 + ] + + numTypes :: [Name] + numTypes = integralTypes ++ floatingTypes + + mkPrim :: Name -> Q [Dec] + mkPrim name = + let t = conT name + in + [d| instance KnownNat n => SIMD n $t where + type VecR n $t = Prim.Vec n $t + vecR = TupRsingle scalarType + vtagsR = [TagRsingle scalarType] + |] + + mkTuple :: Int -> Q Dec + mkTuple n = do + w <- newName "w" + let + xs = [ mkName ('x' : show i) | i <- [0 .. n-1] ] + ts = map varT xs + res = tupT ts + ctx = mapM (appT [t| SIMD $(varT w) |]) ts + -- + instanceD ctx [t| SIMD $(varT w) $res |] [] + -- + ps <- mapM mkPrim numTypes + ts <- mapM mkTuple [2..16] + return (concat ps ++ ts) + + +{-- +type NoTypeError = (() :: Constraint) + +type family NoNestedVec (f :: k) :: Constraint where + NoNestedVec t = If (HasNestedVec (Rep t)) + (TypeError (NestedVecError t)) + NoTypeError + +type family HasNestedVec (f :: k -> Type) :: Bool where + HasNestedVec (f :+: g) = HasNestedVec f || HasNestedVec g + HasNestedVec (f :*: g) = HasNestedVec f || HasNestedVec g + HasNestedVec (M1 _ _ a) = HasNestedVec a + HasNestedVec U1 = 'False + -- HasNestedVec (Rec0 Int) = 'False + -- HasNestedVec (Rec0 Int8) = 'False + -- HasNestedVec (Rec0 Int16) = 'False + -- HasNestedVec (Rec0 Int32) = 'False + -- HasNestedVec (Rec0 Int64) = 'False + -- HasNestedVec (Rec0 Int128) = 'False + -- HasNestedVec (Rec0 Word) = 'False + -- HasNestedVec (Rec0 Word8) = 'False + -- HasNestedVec (Rec0 Word16) = 'False + -- HasNestedVec (Rec0 Word32) = 'False + -- HasNestedVec (Rec0 Word64) = 'False + -- HasNestedVec (Rec0 Word128) = 'False + -- HasNestedVec (Rec0 Half) = 'False + -- HasNestedVec (Rec0 Float) = 'False + -- HasNestedVec (Rec0 Double) = 'False + -- HasNestedVec (Rec0 Float128) = 'False + -- HasNestedVec (Rec0 (Vec _ _)) = 'True + -- HasNestedVec (Rec0 a) = 'False + HasNestedVec (Rec0 a) = Stuck (Rep a) 'False (HasNestedVec (Rep a)) + +-- type family NoNestedVec (f :: k -> Type) :: Constraint where +-- NoNestedVec (f :+: g) = (NoNestedVec f, NoNestedVec g) +-- NoNestedVec (f :*: g) = (NoNestedVec f, NoNestedVec g) +-- NoNestedVec (M1 _ _ a) = NoNestedVec a +-- NoNestedVec U1 = NoTypeError +-- NoNestedVec (Rec0 Int) = NoTypeError +-- NoNestedVec (Rec0 Int8) = NoTypeError +-- NoNestedVec (Rec0 Int16) = NoTypeError +-- NoNestedVec (Rec0 Int32) = NoTypeError +-- NoNestedVec (Rec0 Int64) = NoTypeError +-- NoNestedVec (Rec0 Int128) = NoTypeError +-- NoNestedVec (Rec0 Word) = NoTypeError +-- NoNestedVec (Rec0 Word8) = NoTypeError +-- NoNestedVec (Rec0 Word16) = NoTypeError +-- NoNestedVec (Rec0 Word32) = NoTypeError +-- NoNestedVec (Rec0 Word64) = NoTypeError +-- NoNestedVec (Rec0 Word128) = NoTypeError +-- NoNestedVec (Rec0 Half) = NoTypeError +-- NoNestedVec (Rec0 Float) = NoTypeError +-- NoNestedVec (Rec0 Double) = NoTypeError +-- NoNestedVec (Rec0 Float128) = NoTypeError +-- NoNestedVec (Rec0 (Vec _ _)) = TypeError (NestedVecError Int) +-- NoNestedVec (Rec0 a) = Stuck (Rep a) NoTypeError (NoNestedVec (Rep a)) + +-- NoNestedVec _ = NoTypeError + -- NoNestedVec (K1 R a) = NoNestedVec (Rep a) + -- NoNestedVec (K1 R a) = Stuck (Rep a) (NoGeneric a) (NoNestedVec (Rep a)) + +-- type family IsVec a :: Bool where +-- IsVec (Vec n a) = 'True +-- IsVec _ = 'False + +-- foo :: forall a. Stuck (Rep a) (NoGeneric a) NoTypeError => () +foo :: forall a. NoNestedVec a => () +foo = () + +data T x +type family Any :: k + +type family Stuck (f :: Type -> Type) (c :: Bool) (a :: k) :: k where + Stuck T _ _ = Any + Stuck _ _ k = k +type family NoGeneric t where + NoGeneric x = TypeError ('Text "No instance for " ':<>: 'ShowType (Generic x)) -type VecElt a = (Elt a, Prim a, IsSingle a, EltR a ~ a) +-- type family NoSIMD n t where +-- NoSIMD n t = TypeError ('Text "No instance for " ':<>: 'ShowType (SIMD n t)) -instance (KnownNat n, VecElt a) => Elt (Vec n a) where - type EltR (Vec n a) = Vec n a - eltR = TupRsingle (VectorScalarType (VectorType (fromIntegral (natVal' (proxy# :: Proxy# n))) singleType)) - tagsR = [TagRsingle (VectorScalarType (VectorType (fromIntegral (natVal' (proxy# :: Proxy# n))) singleType))] - toElt = id - fromElt = id +type family NestedVecError (t :: k) :: ErrorMessage where + NestedVecError t = 'Text "Can not derive SIMD class" + ':$$: 'Text "Because '" ':<>: 'ShowType t ':<>: 'Text "' already contains a SIMD vector." +--} diff --git a/src/Data/Array/Accelerate/Test/NoFib/Base.hs b/src/Data/Array/Accelerate/Test/NoFib/Base.hs index 1307c6cbb..0b39afb9e 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Base.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Base.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} @@ -14,21 +16,21 @@ module Data.Array.Accelerate.Test.NoFib.Base where +import Data.Array.Accelerate.Type import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Array import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Sugar.Shape import Data.Array.Accelerate.Trafo.Sharing -import Data.Array.Accelerate.Type -import Data.Primitive.Vec - import Control.Monad -import Data.Primitive.Types import Hedgehog import qualified Hedgehog.Gen as Gen import qualified Hedgehog.Range as Range +import qualified GHC.Exts as GHC + type Run = forall a. Arrays a => Acc a -> a type RunN = forall f. Afunction f => f -> AfunctionR f @@ -94,21 +96,21 @@ f32 = Gen.float (Range.linearFracFrom 0 (-log_flt_max) log_flt_max) f64 :: Gen Double f64 = Gen.double (Range.linearFracFrom 0 (-log_flt_max) log_flt_max) -v2 :: Prim a => Gen a -> Gen (Vec2 a) -v2 a = Vec2 <$> a <*> a +v2 :: (Elt a, SIMD 2 a) => Gen a -> Gen (V2 a) +v2 a = GHC.fromList <$> replicateM 2 a + +v3 :: (Elt a, SIMD 3 a) => Gen a -> Gen (V3 a) +v3 a = GHC.fromList <$> replicateM 3 a -v3 :: Prim a => Gen a -> Gen (Vec3 a) -v3 a = Vec3 <$> a <*> a <*> a +v4 :: (Elt a, SIMD 4 a) => Gen a -> Gen (V4 a) +v4 a = GHC.fromList <$> replicateM 4 a -v4 :: Prim a => Gen a -> Gen (Vec4 a) -v4 a = Vec4 <$> a <*> a <*> a <*> a +v8 :: (Elt a, SIMD 8 a) => Gen a -> Gen (V8 a) +v8 a = GHC.fromList <$> replicateM 8 a -v8 :: Prim a => Gen a -> Gen (Vec8 a) -v8 a = Vec8 <$> a <*> a <*> a <*> a <*> a <*> a <*> a <*> a +v16 :: (Elt a, SIMD 16 a) => Gen a -> Gen (V16 a) +v16 a = GHC.fromList <$> replicateM 16 a -v16 :: Prim a => Gen a -> Gen (Vec16 a) -v16 a = Vec16 <$> a <*> a <*> a <*> a <*> a <*> a <*> a <*> a - <*> a <*> a <*> a <*> a <*> a <*> a <*> a <*> a log_flt_max :: RealFloat a => a log_flt_max = log flt_max diff --git a/src/Data/Array/Accelerate/Test/NoFib/Imaginary/SAXPY.hs b/src/Data/Array/Accelerate/Test/NoFib/Imaginary/SAXPY.hs index 7537d5625..af69b5b28 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Imaginary/SAXPY.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Imaginary/SAXPY.hs @@ -23,12 +23,12 @@ module Data.Array.Accelerate.Test.NoFib.Imaginary.SAXPY ( import Prelude as P import Data.Array.Accelerate as A -import Data.Array.Accelerate.Sugar.Array as S -import Data.Array.Accelerate.Sugar.Elt as S -import Data.Array.Accelerate.Sugar.Shape as S +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config import Data.Array.Accelerate.Test.Similar +import qualified Data.Array.Accelerate.Sugar.Array as S +import qualified Data.Array.Accelerate.Sugar.Shape as S import Hedgehog import qualified Hedgehog.Gen as Gen diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue102.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue102.hs index dedec6cf6..99960adb0 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue102.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue102.hs @@ -1,5 +1,6 @@ -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} -- | -- Module : Data.Array.Accelerate.Test.NoFib.Issues.Issue102 -- Copyright : [2009..2020] The Accelerate Team @@ -40,29 +41,29 @@ test1 = rts = 1 rustride = 1 - v = fill (constant (Z:.(p-1))) (constant 2) - ru' = fill (constant (Z:.(p-1))) (constant 1) + v = fill (I1 (p-1)) 2 + ru' = fill (I1 (p-1)) 1 -- generate a vector with phi(p)=p-1 elements - x' = reshape (constant (Z :. lts :. (p-1) :. rts)) v + x' = reshape (I3 lts (p-1) rts) v --embed into a vector of length p - y = generate (constant (Z :. lts :. p :. rts)) - (\ix -> let (Z :. l :. i :. r) = unlift ix :: Z :. Exp Int :. Exp Int :. Exp Int - in i A.== 0 ? (0, x' ! (lift $ Z :. l :. i-1 :. r))) + y = generate (I3 lts p rts) + (\ix -> let I3 l i r = ix + in i A.== 0 ? (0, x' ! (I3 l (i-1) r))) -- do a DFT_p - y' = reshape (constant (Z :. lts :. p :. rts)) (flatten y) - dftrus = generate (constant (Z :. p :. p)) - (\ix -> let (Z :. i :. j) = unlift ix :: Z :. Exp Int :. Exp Int - in ru' ! (lift (Z :. (i*j*rustride `mod` (constant p))))) + y' = reshape (I3 lts p rts) (flatten y) + dftrus = generate (I2 p p) + (\ix -> let I2 i j = ix + in ru' ! (I1 (i*j*rustride `mod` p))) - tensorDFTCoeffs = A.replicate (lift (Z:.lts:.All:.rts:.All)) dftrus + tensorDFTCoeffs = A.replicate @(Z :. Int :. All :. Int :. All) (I4 lts All rts All) dftrus tensorInputCoeffs = generate (shape tensorDFTCoeffs) - (\ix -> let (Z:.l:._:.r:.col) = unlift ix :: Z :. Exp Int :. Exp Int :. Exp Int :. Exp Int - in y' ! (lift $ Z:.l:.col:.r)) + (\ix -> let I4 l _ r col = ix + in y' ! (I3 l col r)) - dftans = flatten $ fold (+) (constant 0) $ A.zipWith (*) tensorDFTCoeffs tensorInputCoeffs + dftans = flatten $ fold (+) 0 $ A.zipWith (*) tensorDFTCoeffs tensorInputCoeffs --continue the alternate transform, but this line breaks dfty = reshape (shape y) $ dftans diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue137.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue137.hs index 444687f44..0caa35b2b 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue137.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue137.hs @@ -24,6 +24,7 @@ import Data.Array.Accelerate.Test.NoFib.Base import Test.Tasty import Test.Tasty.HUnit +import Prelude hiding ( Maybe(..) ) test_issue137 :: RunN -> TestTree @@ -38,19 +39,19 @@ test1 :: Acc (Vector (Int,Int)) test1 = let sz = 3000 :: Int - interm_arrA = use $ A.fromList (Z :. sz) [ 8 - (a `mod` 17) | a <- [1..sz]] + interm_arrA = use $ A.fromList (Z :. sz) [ 8 - (a `mod` 17) | a <- [1..sz]] :: Acc (Vector Int) msA = use $ A.fromList (Z :. sz) [ (a `div` 8) | a <- [1..sz]] inf = 10000 :: Exp Int - infsA = A.generate (index1 (384 :: Exp Int)) (\_ -> lift (inf,inf)) - inpA = A.map (\v -> lift (abs v,inf) :: Exp (Int,Int)) interm_arrA + infsA = A.generate (index1 (384 :: Exp Int)) (\_ -> T2 inf inf) + inpA = A.map (\v -> T2 (abs v) inf) interm_arrA in - A.permute (\a12 b12 -> let (a1,a2) = unlift a12 - (b1,b2) = unlift b12 + A.permute (\a12 b12 -> let T2 a1 a2 = a12 + T2 b1 b2 = b12 in (a1 A.<= b1) - ? ( lift (a1, A.min a2 b1) - , lift (b1, A.min b2 a1) + ? ( T2 a1 (A.min a2 b1) + , T2 b1 (A.min b2 a1) )) infsA - (\ix -> Just_ (index1 (msA A.! ix))) + (\ix -> Just (index1 (msA A.! ix))) inpA diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue185.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue185.hs index f14f83520..53c834915 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue185.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue185.hs @@ -28,7 +28,7 @@ import Data.Array.Accelerate.Test.NoFib.Base import Test.Tasty import Test.Tasty.HUnit -import Prelude as P +import Prelude as P hiding ( Maybe(..) ) test_issue185 :: RunN -> TestTree @@ -144,6 +144,6 @@ scatterIf -> Acc (Vector e') scatterIf to maskV p def input = permute const def pf input' where - pf ix = p (maskV ! ix) ? ( Just_ (index1 (to ! ix)), Nothing_ ) + pf ix = p (maskV ! ix) ? ( Just (index1 (to ! ix)), Nothing ) input' = backpermute (shape to `intersect` shape input) id input diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue264.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue264.hs index afdbe9294..187984df7 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue264.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue264.hs @@ -26,12 +26,12 @@ module Data.Array.Accelerate.Test.NoFib.Issues.Issue264 ( import Prelude as P import Data.Array.Accelerate as A -import Data.Array.Accelerate.Sugar.Array as S -import Data.Array.Accelerate.Sugar.Elt as S -import Data.Array.Accelerate.Sugar.Shape as S +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config import Data.Array.Accelerate.Test.Similar +import qualified Data.Array.Accelerate.Sugar.Array as S +import qualified Data.Array.Accelerate.Sugar.Shape as S import Hedgehog import qualified Hedgehog.Gen as Gen @@ -80,7 +80,7 @@ test_not_not -> Property test_not_not runN = property $ do - xs <- forAll (array Z Gen.bool) + xs <- forAll (array @Z Z Gen.bool) let !go = runN (A.map A.not . A.map A.not) in go xs === mapRef (P.not . P.not) xs test_not_and @@ -88,8 +88,8 @@ test_not_and -> Property test_not_and runN = property $ do - xs <- forAll (array Z Gen.bool) - ys <- forAll (array Z Gen.bool) + xs <- forAll (array @Z Z Gen.bool) + ys <- forAll (array @Z Z Gen.bool) let !go = runN (A.zipWith (\u v -> A.not (u A.&& v))) in go xs ys === zipWithRef (\u v -> P.not (u P.&& v)) xs ys test_not_or @@ -97,8 +97,8 @@ test_not_or -> Property test_not_or runN = property $ do - xs <- forAll (array Z Gen.bool) - ys <- forAll (array Z Gen.bool) + xs <- forAll (array @Z Z Gen.bool) + ys <- forAll (array @Z Z Gen.bool) let !go = runN (A.zipWith (\u v -> A.not (u A.|| v))) in go xs ys === zipWithRef (\u v -> P.not (u P.|| v)) xs ys test_not_not_and @@ -106,8 +106,8 @@ test_not_not_and -> Property test_not_not_and runN = property $ do - xs <- forAll (array Z Gen.bool) - ys <- forAll (array Z Gen.bool) + xs <- forAll (array @Z Z Gen.bool) + ys <- forAll (array @Z Z Gen.bool) let !go = runN (A.zipWith (\u v -> A.not (A.not (u A.&& v)))) in go xs ys === zipWithRef (\u v -> P.not (P.not (u P.&& v))) xs ys test_not_not_or @@ -115,8 +115,8 @@ test_not_not_or -> Property test_not_not_or runN = property $ do - xs <- forAll (array Z Gen.bool) - ys <- forAll (array Z Gen.bool) + xs <- forAll (array @Z Z Gen.bool) + ys <- forAll (array @Z Z Gen.bool) let !go = runN (A.zipWith (\u v -> A.not (A.not (u A.|| v)))) in go xs ys === zipWithRef (\u v -> P.not (P.not (u P.|| v))) xs ys test_neg_neg @@ -133,7 +133,7 @@ test_neg_neg runN e = mapRef :: (Shape sh, Elt a, Elt b) => (a -> b) -> Array sh a -> Array sh b -mapRef f xs = fromFunction (S.shape xs) (\ix -> f (xs S.! ix)) +mapRef f xs = fromFunction (arrayShape xs) (\ix -> f (xs S.! ix)) zipWithRef :: (Shape sh, Elt a, Elt b, Elt c) @@ -143,6 +143,6 @@ zipWithRef -> Array sh c zipWithRef f xs ys = fromFunction - (S.shape xs `S.intersect` S.shape ys) + (arrayShape xs `S.intersect` arrayShape ys) (\ix -> f (xs S.! ix) (ys S.! ix)) diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue288.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue288.hs index 84c18a7d5..101d5bede 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue288.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue288.hs @@ -23,7 +23,7 @@ import Data.Array.Accelerate.Test.NoFib.Base import Test.Tasty import Test.Tasty.HUnit -import Prelude as P +import Prelude as P hiding ( Bool(..) ) test_issue288 :: RunN -> TestTree @@ -31,7 +31,7 @@ test_issue288 runN = testCase "288" $ xs @=? runN (A.map f) xs f :: Exp (Int, Int) -> Exp (Int, Int) -f e = while (const (lift False)) id e +f e = while (const False) id e xs :: Vector (Int, Int) xs = fromList (Z:.10) (P.zip [1..] [1..]) diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue364.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue364.hs index ddc167342..8365a70eb 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue364.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue364.hs @@ -23,17 +23,16 @@ module Data.Array.Accelerate.Test.NoFib.Issues.Issue364 ( ) where -import Prelude ( fromInteger, show ) -import qualified Prelude as P +import Prelude ( show ) +import qualified Prelude as P -import Data.Array.Accelerate hiding ( fromInteger ) -import Data.Array.Accelerate.Sugar.Elt as S -import Data.Array.Accelerate.Sugar.Shape as S +import Data.Array.Accelerate +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config +import qualified Data.Array.Accelerate.Sugar.Shape as S import Hedgehog - import Test.Tasty import Test.Tasty.HUnit @@ -55,8 +54,8 @@ test_issue364 runN = -> TestTree testElt _ = testGroup (show (eltR @e)) - [ testCase "A" $ expectedArray @_ @e Z 64 @=? runN (scanl iappend one) (intervalArray Z 64) - , testCase "B" $ expectedArray @_ @e Z 65 @=? runN (scanl iappend one) (intervalArray Z 65) -- failed for integral types + [ testCase "A" $ expectedArray @Z @e Z 64 @=? runN (scanl iappend one) (intervalArray Z 64) + , testCase "B" $ expectedArray @Z @e Z 65 @=? runN (scanl iappend one) (intervalArray Z 65) -- failed for integral types ] diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs index 1f42aa3f9..5ac158df3 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue407.hs @@ -1,10 +1,12 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} -- | -- Module : Data.Array.Accelerate.Test.NoFib.Issues.Issue407 -- Copyright : [2009..2020] The Accelerate Team @@ -24,10 +26,11 @@ module Data.Array.Accelerate.Test.NoFib.Issues.Issue407 ( ) where -import Prelude as P +import Prelude as P hiding ( Bool(..) ) import Data.Array.Accelerate as A -import Data.Array.Accelerate.Sugar.Elt as S +import Data.Array.Accelerate.AST ( BitOrMask, PrimBool ) +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Test.NoFib.Base import Test.Tasty @@ -42,16 +45,18 @@ test_issue407 runN = ] where testElt - :: forall a. (Show a, P.Fractional a, A.RealFloat a) + :: forall a. (Show a, P.Fractional a, A.RealFloat a, BitOrMask (EltR a) ~ PrimBool) => TestTree testElt = + let xs :: Vector a + xs = [0/0, -2/0, -0/0, 0.1, 1/0, 0.5, 5/0] + + eNaN, eInf :: Vector Bool + eNaN = [True, False, True, False, False, False, False] -- expected: isNaN + eInf = [False, True, False, False, True, False, True] -- expected: isInfinite + in testGroup (show (eltR @a)) [ testCase "isNaN" $ eNaN @=? runN (A.map A.isNaN) xs , testCase "isInfinite" $ eInf @=? runN (A.map A.isInfinite) xs ] - where - xs :: Vector a - xs = [0/0, -2/0, -0/0, 0.1, 1/0, 0.5, 5/0] - eNaN = [True, False, True, False, False, False, False] -- expected: isNaN - eInf = [False, True, False, False, True, False, True] -- expected: isInfinite diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue427.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue427.hs index 1988d8c16..17a40b098 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue427.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue427.hs @@ -41,6 +41,7 @@ test_issue427 runN , testProperty "n-by-m" $ test_indicesOfTruth runN dim2 ] where + by :: Int -> Gen DIM2 by x = do y <- Gen.int (Range.linear 0 1024) pure (Z :. y :. x) diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue436.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue436.hs index 57420f447..3663bf9bc 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue436.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue436.hs @@ -18,11 +18,12 @@ module Data.Array.Accelerate.Test.NoFib.Issues.Issue436 ( ) where -import Data.Array.Accelerate as A +import Data.Array.Accelerate as A import Data.Array.Accelerate.Test.NoFib.Base import Test.Tasty import Test.Tasty.HUnit +import Prelude hiding ( Bool(..) ) test_issue436 :: RunN -> TestTree diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue437.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue437.hs index 956b6956e..a4d516deb 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue437.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue437.hs @@ -73,7 +73,10 @@ test_issue437 runN (alloc P.> 0 P.&& to P.== 4 P.&& from P.== 0) -- remote memory space where xs :: (Scalar Float, Matrix Float) - xs = runN $ T2 (unit 42) (fill (constant $ Z:.10000:.10000) 1) + xs = runN g + where + g :: Acc (Scalar Float, Matrix Float) + g = T2 (unit 42) (fill (Z :. 10000 :. 10000) 1) go :: Arrays a => a -> a go = runN f diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue439.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue439.hs index 3bfa18b40..d6f9b1a81 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue439.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue439.hs @@ -33,5 +33,5 @@ e1 :: Scalar Float e1 = fromList Z [2] t1 :: Acc (Scalar Float) -t1 = compute . A.map (* 2) . compute $ fill Z_ 1 +t1 = compute . A.map (* 2) . compute $ fill Z 1 diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue517.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue517.hs index 0d92a9400..17619264e 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue517.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue517.hs @@ -17,13 +17,15 @@ module Data.Array.Accelerate.Test.NoFib.Issues.Issue517 ( ) where -import Data.Array.Accelerate as A -import Data.Array.Accelerate.Data.Semigroup as A +import Data.Array.Accelerate as A +import Data.Array.Accelerate.Data.Semigroup as A import Data.Array.Accelerate.Test.NoFib.Base import Test.Tasty import Test.Tasty.HUnit +import Prelude hiding ( Maybe(..) ) + test_issue517 :: RunN -> TestTree test_issue517 runN @@ -37,9 +39,9 @@ e1 = fromList Z [(Nothing, Just 2, Just 3, Just 5, Just 7)] t1 :: Acc (Scalar (Tup5 (Maybe (Max Float)))) t1 = unit $ - T5 (Nothing_ <> Nothing_) - (Nothing_ <> Just_ 2) - (Just_ 3 <> Nothing_) - (Just_ 4 <> Just_ 5) - (Just_ 7 <> Just_ 6) + T5 (Nothing <> Nothing) + (Nothing <> Just 2) + (Just 3 <> Nothing) + (Just 4 <> Just 5) + (Just 7 <> Just 6) diff --git a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue93.hs b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue93.hs index 556b27ffb..a2bd5fdf8 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue93.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Issues/Issue93.hs @@ -22,6 +22,7 @@ import Data.Array.Accelerate.Test.NoFib.Base import Test.Tasty import Test.Tasty.HUnit +import Prelude hiding ( Maybe(..) ) test_issue93 :: RunN -> TestTree @@ -32,7 +33,7 @@ xs :: Array DIM2 Int xs = fromList (Z :. 1 :. 1) [5] test1 :: Acc (Array DIM2 Int) -test1 = permute (\c _ -> c) (fill (shape xs') (constant 0)) Just_ xs' +test1 = permute (\c _ -> c) (fill (shape xs') (constant 0)) Just xs' where xs' = use xs diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Backpermute.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Backpermute.hs index 32bad367e..8620d2700 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Backpermute.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Backpermute.hs @@ -24,12 +24,11 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.Backpermute ( import Prelude as P import Data.Array.Accelerate as A -import Data.Array.Accelerate.Sugar.Array as S -import Data.Array.Accelerate.Sugar.Elt as S -import Data.Array.Accelerate.Sugar.Shape as S +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config import Data.Array.Accelerate.Test.Similar +import qualified Data.Array.Accelerate.Sugar.Shape as S import Hedgehog import qualified Hedgehog.Gen as Gen @@ -71,7 +70,7 @@ test_backpermute runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "take" $ test_take runN sh e , testProperty "drop" $ test_drop runN sh e @@ -127,7 +126,7 @@ test_gather runN dim dim' e = -- let !go = runN $ \i -> A.backpermute (A.shape i) (i A.!) -- - go ix xs ~~~ backpermuteRef sh' (ix S.!) xs + go ix xs ~~~ backpermuteRef sh' (ix `indexArray`) xs scalar :: Elt e => e -> Scalar e @@ -140,7 +139,7 @@ backpermuteRef -> Array sh e -> Array sh' e backpermuteRef sh' p arr = - fromFunction sh' (\ix -> arr S.! p ix) + fromFunction sh' (\ix -> arr `indexArray` p ix) takeRef :: (Shape sh, Slice sh, Elt e) @@ -148,8 +147,8 @@ takeRef -> Array (sh:.Int) e -> Array (sh:.Int) e takeRef n arr = - let sh :. m = S.shape arr - in fromFunction (sh :. P.min m n) (arr S.!) + let sh :. m = arrayShape arr + in fromFunction (sh :. P.min m n) (arr `indexArray`) dropRef :: (Shape sh, Slice sh, Elt e) @@ -157,7 +156,7 @@ dropRef -> Array (sh:.Int) e -> Array (sh:.Int) e dropRef n arr = - let sh :. m = S.shape arr + let sh :. m = arrayShape arr n' = P.max 0 n - in fromFunction (sh :. P.max 0 (m - n')) (\(sz:.i) -> arr S.! (sz :. i+n')) + in fromFunction (sh :. P.max 0 (m - n')) (\(sz:.i) -> arr `indexArray` (sz :. i+n')) diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Filter.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Filter.hs index cfafc2da1..164f598ed 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Filter.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Filter.hs @@ -26,12 +26,12 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.Filter ( import Prelude as P import Data.Array.Accelerate as A -import Data.Array.Accelerate.Sugar.Array as S -import Data.Array.Accelerate.Sugar.Elt as S -import Data.Array.Accelerate.Sugar.Shape as S +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config import Data.Array.Accelerate.Test.Similar +import qualified Data.Array.Accelerate.Sugar.Array as S +import qualified Data.Array.Accelerate.Sugar.Shape as S import Hedgehog @@ -71,7 +71,7 @@ test_filter runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "even" $ test_even runN sh e ] @@ -91,7 +91,7 @@ test_filter runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "positive" $ test_positive runN sh e ] diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Fold.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Fold.hs index 6e19ee3dc..6dce22603 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Fold.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Fold.hs @@ -25,11 +25,11 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.Fold ( import Prelude as P import Data.Array.Accelerate as A -import Data.Array.Accelerate.Sugar.Elt as S -import Data.Array.Accelerate.Sugar.Shape as S +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config import Data.Array.Accelerate.Test.Similar +import qualified Data.Array.Accelerate.Sugar.Shape as S import Hedgehog import qualified Hedgehog.Gen as Gen @@ -72,7 +72,7 @@ test_fold runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_sum runN sh e e @@ -112,7 +112,7 @@ test_foldSeg runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_segmented_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_segmented_sum runN sh e e @@ -187,7 +187,7 @@ test_segmented_sum runN dim z e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (S.rank @sh))))) xs <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (\v -> A.foldSeg (+) (the v)) in go (scalar x) xs seg ~~~ foldSegRef (+) x xs seg @@ -202,7 +202,7 @@ test_segmented_minimum runN dim e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (S.rank @sh))))) xs <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (A.fold1Seg A.min) in go xs seg ~~~ fold1SegRef P.min xs seg @@ -217,7 +217,7 @@ test_segmented_maximum runN dim e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (S.rank @sh))))) xs <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (A.fold1Seg A.max) in go xs seg ~~~ fold1SegRef P.max xs seg diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Map.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Map.hs index dec03973d..71aa3b2a4 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Map.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Map.hs @@ -61,7 +61,7 @@ test_map runN = testIntegralElt :: forall a. ( P.Integral a, P.FiniteBits a , A.Integral a, A.FiniteBits a - , A.FromIntegral a Double + , A.FromIntegral a Int, A.FromIntegral a Double , Similar a, Show a ) => Gen a -> TestTree @@ -94,7 +94,7 @@ test_map runN = ] testFloatingElt - :: forall a. (P.RealFloat a, A.Floating a, A.RealFrac a, Similar a, Show a) + :: forall a. (P.RealFloat a, A.Floating a, A.RealFrac a, Similar a, Show a, FromIntegral (Significand a) Int) => (Range a -> Gen a) -> TestTree testFloatingElt e = @@ -194,7 +194,7 @@ test_complement runN dim e = let !go = runN (A.map A.complement) in go xs ~~~ mapRef P.complement xs test_popCount - :: (Shape sh, Show sh, Show e, A.Bits e, P.Bits e, P.Eq sh) + :: (Shape sh, Show sh, Show e, A.Bits e, P.Bits e, P.Eq sh, A.Ord e, A.Integral e, A.FromIntegral e Int) => RunN -> Gen sh -> Gen e @@ -203,10 +203,10 @@ test_popCount runN dim e = property $ do sh <- forAll dim xs <- forAll (array sh e) - let !go = runN (A.map A.popCount) in go xs ~~~ mapRef P.popCount xs + let !go = runN (A.map (A.fromIntegral . A.popCount)) in go xs ~~~ mapRef P.popCount xs test_countLeadingZeros - :: (Shape sh, Show sh, Show e, A.FiniteBits e, P.FiniteBits e, P.Eq sh) + :: (Shape sh, Show sh, Show e, A.FiniteBits e, P.FiniteBits e, P.Eq sh, A.Ord e, A.Integral e, A.FromIntegral e Int) => RunN -> Gen sh -> Gen e @@ -215,10 +215,10 @@ test_countLeadingZeros runN dim e = property $ do sh <- forAll dim xs <- forAll (array sh e) - let !go = runN (A.map A.countLeadingZeros) in go xs ~~~ mapRef countLeadingZerosRef xs + let !go = runN (A.map (A.fromIntegral . A.countLeadingZeros)) in go xs ~~~ mapRef countLeadingZerosRef xs test_countTrailingZeros - :: (Shape sh, Show sh, Show e, A.FiniteBits e, P.FiniteBits e, P.Eq sh) + :: (Shape sh, Show sh, Show e, A.FiniteBits e, P.FiniteBits e, P.Eq sh, A.Ord e, A.Integral e, A.FromIntegral e Int) => RunN -> Gen sh -> Gen e @@ -227,7 +227,7 @@ test_countTrailingZeros runN dim e = property $ do sh <- forAll dim xs <- forAll (array sh e) - let !go = runN (A.map A.countTrailingZeros) in go xs ~~~ mapRef countTrailingZerosRef xs + let !go = runN (A.map (A.fromIntegral . A.countTrailingZeros)) in go xs ~~~ mapRef countTrailingZerosRef xs test_fromIntegral :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.Integral e, A.Integral e, A.FromIntegral e Double) @@ -398,7 +398,7 @@ test_log runN dim e = let !go = runN (A.map log) in go xs ~~~ mapRef log xs test_truncate - :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.RealFrac e, A.RealFrac e) + :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.RealFrac e, A.RealFrac e, FromIntegral (Significand e) Int) => RunN -> Gen sh -> Gen e @@ -410,7 +410,7 @@ test_truncate runN dim e = let !go = runN (A.map A.truncate) in go xs ~~~ mapRef (P.truncate :: e -> Int) xs test_round - :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.RealFrac e, A.RealFrac e) + :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.RealFrac e, A.RealFrac e, FromIntegral (Significand e) Int) => RunN -> Gen sh -> Gen e @@ -422,7 +422,7 @@ test_round runN dim e = let !go = runN (A.map A.round) in go xs ~~~ mapRef (P.round :: e -> Int) xs test_floor - :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.RealFrac e, A.RealFrac e) + :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.RealFrac e, A.RealFrac e, FromIntegral (Significand e) Int) => RunN -> Gen sh -> Gen e @@ -434,7 +434,7 @@ test_floor runN dim e = let !go = runN (A.map A.floor) in go xs ~~~ mapRef (P.floor :: e -> Int) xs test_ceiling - :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.RealFrac e, A.RealFrac e) + :: forall sh e. (Shape sh, Show sh, Show e, P.Eq sh, P.RealFrac e, A.RealFrac e, FromIntegral (Significand e) Int) => RunN -> Gen sh -> Gen e diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Permute.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Permute.hs index 91ebc7484..13131a118 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Permute.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Permute.hs @@ -39,7 +39,7 @@ import Test.Tasty import Test.Tasty.Hedgehog import System.IO.Unsafe -import Prelude as P +import Prelude as P hiding ( Bool(..), Maybe(..) ) import qualified Data.Set as Set diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs index 525f9e2c7..a7ea66761 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/SIMD.hs @@ -1,9 +1,11 @@ {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} -- | -- Module : Data.Array.Accelerate.Test.NoFib.Prelude.SIMD -- Copyright : [2009..2020] The Accelerate Team @@ -20,8 +22,8 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.SIMD ( ) where -import Lens.Micro ( _1, _2, _3, _4 ) -import Lens.Micro.Extras ( view ) +import Lens.Micro ( _1, _2, _3, _4 ) +import Lens.Micro.Extras ( view ) import Prelude as P import Data.Array.Accelerate as A @@ -30,8 +32,6 @@ import Data.Array.Accelerate.Sugar.Elt as S import Data.Array.Accelerate.Sugar.Shape as S import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config -import Data.Primitive.Vec -import Data.Primitive.Types import Hedgehog import qualified Hedgehog.Gen as Gen @@ -56,7 +56,7 @@ test_simd runN = , at @TestDouble $ testElt f64 ] where - testElt :: forall e. (VecElt e, P.Eq e, Show e) + testElt :: forall e. (Elt e, SIMD 2 e, SIMD 3 e, SIMD 4 e, P.Eq e, Show e) => Gen e -> TestTree testElt e = @@ -65,7 +65,7 @@ test_simd runN = , testInject e ] - testExtract :: forall e. (VecElt e, P.Eq e, Show e) + testExtract :: (Elt e, SIMD 2 e, SIMD 3 e, SIMD 4 e, P.Eq e, Show e) => Gen e -> TestTree testExtract e = @@ -75,7 +75,7 @@ test_simd runN = , testProperty "V4" $ test_extract_v4 runN dim1 e ] - testInject :: forall e. (VecElt e, P.Eq e, Show e) + testInject :: (Elt e, SIMD 2 e, SIMD 3 e, SIMD 4 e, P.Eq e, Show e) => Gen e -> TestTree testInject e = @@ -87,7 +87,7 @@ test_simd runN = test_extract_v2 - :: (Shape sh, Show sh, Show e, VecElt e, P.Eq e, P.Eq sh) + :: (Shape sh, Show sh, Show e, Elt e, SIMD 2 e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -100,7 +100,7 @@ test_extract_v2 runN dim e = let !go = runN (A.map (view _m . unpackVec2')) in go xs === mapRef (view _l . unpackVec2) xs test_extract_v3 - :: (Shape sh, Show sh, Show e, VecElt e, P.Eq e, P.Eq sh) + :: (Shape sh, Show sh, Show e, Elt e, SIMD 3 e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -113,7 +113,7 @@ test_extract_v3 runN dim e = let !go = runN (A.map (view _m . unpackVec3')) in go xs === mapRef (view _l . unpackVec3) xs test_extract_v4 - :: (Shape sh, Show sh, Show e, VecElt e, P.Eq e, P.Eq sh) + :: (Shape sh, Show sh, Show e, Elt e, SIMD 4 e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -126,7 +126,7 @@ test_extract_v4 runN dim e = let !go = runN (A.map (view _m . unpackVec4')) in go xs === mapRef (view _l . unpackVec4) xs test_inject_v2 - :: (Shape sh, Show sh, Show e, VecElt e, P.Eq e, P.Eq sh) + :: (Shape sh, Show sh, Show e, Elt e, SIMD 2 e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -137,10 +137,10 @@ test_inject_v2 runN dim e = sh2 <- forAll dim xs <- forAll (array sh1 e) ys <- forAll (array sh2 e) - let !go = runN (A.zipWith A.V2) in go xs ys === zipWithRef Vec2 xs ys + let !go = runN (A.zipWith V2) in go xs ys === zipWithRef V2 xs ys test_inject_v3 - :: (Shape sh, Show sh, Show e, VecElt e, P.Eq e, P.Eq sh) + :: (Shape sh, Show sh, Show e, Elt e, SIMD 3 e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -153,10 +153,10 @@ test_inject_v3 runN dim e = xs <- forAll (array sh1 e) ys <- forAll (array sh2 e) zs <- forAll (array sh3 e) - let !go = runN (A.zipWith3 A.V3) in go xs ys zs === zipWith3Ref Vec3 xs ys zs + let !go = runN (A.zipWith3 V3) in go xs ys zs === zipWith3Ref V3 xs ys zs test_inject_v4 - :: (Shape sh, Show sh, Show e, VecElt e, P.Eq e, P.Eq sh) + :: (Shape sh, Show sh, Show e, Elt e, SIMD 4 e, P.Eq e, P.Eq sh) => RunN -> Gen sh -> Gen e @@ -171,26 +171,26 @@ test_inject_v4 runN dim e = ys <- forAll (array sh2 e) zs <- forAll (array sh3 e) ws <- forAll (array sh4 e) - let !go = runN (A.zipWith4 A.V4) in go xs ys zs ws === zipWith4Ref Vec4 xs ys zs ws + let !go = runN (A.zipWith4 V4) in go xs ys zs ws === zipWith4Ref V4 xs ys zs ws -unpackVec2 :: Prim e => Vec2 e -> (e, e) -unpackVec2 (Vec2 a b) = (a, b) +unpackVec2 :: (Elt e, SIMD 2 e) => V2 e -> (e, e) +unpackVec2 (V2 a b) = (a, b) -unpackVec3 :: Prim e => Vec3 e -> (e, e, e) -unpackVec3 (Vec3 a b c) = (a, b, c) +unpackVec3 :: (Elt e, SIMD 3 e) => V3 e -> (e, e, e) +unpackVec3 (V3 a b c) = (a, b, c) -unpackVec4 :: Prim e => Vec4 e -> (e, e, e, e) -unpackVec4 (Vec4 a b c d) = (a, b, c, d) +unpackVec4 :: (Elt e, SIMD 4 e) => V4 e -> (e, e, e, e) +unpackVec4 (V4 a b c d) = (a, b, c, d) -unpackVec2' :: VecElt e => Exp (Vec2 e) -> (Exp e, Exp e) -unpackVec2' (A.V2 a b) = (a, b) +unpackVec2' :: (Elt e, SIMD 2 e) => Exp (V2 e) -> (Exp e, Exp e) +unpackVec2' (V2 a b) = (a, b) -unpackVec3' :: VecElt e => Exp (Vec3 e) -> (Exp e, Exp e, Exp e) -unpackVec3' (A.V3 a b c) = (a, b, c) +unpackVec3' :: (Elt e, SIMD 3 e) => Exp (V3 e) -> (Exp e, Exp e, Exp e) +unpackVec3' (V3 a b c) = (a, b, c) -unpackVec4' :: VecElt e => Exp (Vec4 e) -> (Exp e, Exp e, Exp e, Exp e) -unpackVec4' (A.V4 a b c d) = (a, b, c, d) +unpackVec4' :: (Elt e, SIMD 4 e) => Exp (V4 e) -> (Exp e, Exp e, Exp e, Exp e) +unpackVec4' (V4 a b c d) = (a, b, c, d) -- Reference Implementation diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Scan.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Scan.hs index f1a0f587d..4b94704c9 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Scan.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Scan.hs @@ -30,11 +30,11 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.Scan ( import Prelude as P import Data.Array.Accelerate as A -import Data.Array.Accelerate.Sugar.Elt as S -import Data.Array.Accelerate.Sugar.Shape as S +import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config import Data.Array.Accelerate.Test.Similar +import qualified Data.Array.Accelerate.Sugar.Shape as S import Hedgehog import qualified Hedgehog.Gen as Gen @@ -76,7 +76,7 @@ test_scanl runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanl_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_scanl_sum runN sh e e , testProperty "non-commutative" $ test_scanl_interval runN sh e @@ -112,7 +112,7 @@ test_scanl1 runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanl1_sum runN sh e , testProperty "non-commutative" $ test_scanl1_interval runN sh e ] @@ -147,7 +147,7 @@ test_scanl' runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanl'_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_scanl'_sum runN sh e e , testProperty "non-commutative" $ test_scanl'_interval runN sh e @@ -183,7 +183,7 @@ test_scanr runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanr_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_scanr_sum runN sh e e , testProperty "non-commutative" $ test_scanr_interval runN sh e @@ -219,7 +219,7 @@ test_scanr1 runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanr1_sum runN sh e , testProperty "non-commutative" $ test_scanr1_interval runN sh e ] @@ -254,7 +254,7 @@ test_scanr' runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanr'_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_scanr'_sum runN sh e e , testProperty "non-commutative" $ test_scanr'_interval runN sh e @@ -290,7 +290,7 @@ test_scanlSeg runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanlSeg_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_scanlSeg_sum runN sh e e ] @@ -325,7 +325,7 @@ test_scanl1Seg runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanl1Seg_sum runN sh e ] @@ -359,7 +359,7 @@ test_scanl'Seg runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanl'Seg_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_scanl'Seg_sum runN sh e e ] @@ -394,7 +394,7 @@ test_scanrSeg runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanrSeg_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_scanrSeg_sum runN sh e e ] @@ -429,7 +429,7 @@ test_scanr1Seg runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanr1Seg_sum runN sh e ] @@ -463,7 +463,7 @@ test_scanr'Seg runN = => Gen (sh:.Int) -> TestTree testDim sh = - testGroup ("DIM" P.++ show (rank @(sh:.Int))) + testGroup ("DIM" P.++ show (S.rank @(sh:.Int))) [ testProperty "sum" $ test_scanr'Seg_sum runN sh (pure 0) e , testProperty "non-neutral sum" $ test_scanr'Seg_sum runN sh e e ] @@ -636,7 +636,7 @@ test_scanlSeg_sum runN dim z e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (S.rank @sh))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (\v -> A.scanlSeg (+) (the v)) in go (scalar x) arr seg ~~~ scanlSegRef (+) x arr seg @@ -651,7 +651,7 @@ test_scanl1Seg_sum runN dim e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 1 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (S.rank @sh))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (A.scanl1Seg (+)) in go arr seg ~~~ scanl1SegRef (+) arr seg @@ -668,7 +668,7 @@ test_scanl'Seg_sum runN dim z e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (S.rank @sh))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (\v -> A.scanl'Seg (+) (the v)) in go (scalar x) arr seg ~~~ scanl'SegRef (+) x arr seg @@ -685,7 +685,7 @@ test_scanrSeg_sum runN dim z e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (S.rank @sh))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (\v -> A.scanrSeg (+) (the v)) in go (scalar x) arr seg ~~~ scanrSegRef (+) x arr seg @@ -700,7 +700,7 @@ test_scanr1Seg_sum runN dim e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 1 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (S.rank @sh))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (A.scanr1Seg (+)) in go arr seg ~~~ scanr1SegRef (+) arr seg @@ -717,7 +717,7 @@ test_scanr'Seg_sum runN dim z e = sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- pure (P.min n1 n2) -- don't generate too many segments - seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (rank @sh))))) + seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (S.rank @sh))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (\v -> A.scanr'Seg (+) (the v)) in go (scalar x) arr seg ~~~ scanr'SegRef (+) x arr seg diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs index dec121c68..54ee71e7e 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/Stencil.hs @@ -25,7 +25,7 @@ module Data.Array.Accelerate.Test.NoFib.Prelude.Stencil ( ) where import Data.Typeable -import Prelude as P +import Prelude as P hiding ( Maybe(..), Either(..) ) import Data.Array.Accelerate as A import Data.Array.Accelerate.Sugar.Elt as S @@ -635,7 +635,7 @@ bound bnd sh0 ix0 = go TupRunit () () = Right () go (TupRpair tsh tsz) (sh,sz) (ih,iz) = go tsh sh ih `addDim` go tsz sz iz go (TupRsingle t) sh i - | Just Refl <- matchScalarType t (scalarType :: ScalarType Int) + | Just Refl <- matchScalarType t (scalarType :: ScalarType INT) = if i P.< 0 then case bnd of Clamp -> Right 0 @@ -655,6 +655,7 @@ bound bnd sh0 ix0 = | otherwise = error "bound: expected shape with Int dimensions" + addDim :: Either e ds -> Either e d -> Either e (ds, d) Right ds `addDim` Right d = Right (ds, d) _ `addDim` Left e = Left e Left e `addDim` _ = Left e diff --git a/src/Data/Array/Accelerate/Test/NoFib/Prelude/ZipWith.hs b/src/Data/Array/Accelerate/Test/NoFib/Prelude/ZipWith.hs index b73785db2..e04b1ca00 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Prelude/ZipWith.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Prelude/ZipWith.hs @@ -61,7 +61,7 @@ test_zipWith runN = where testIntegralElt :: forall a. ( P.Integral a, P.FiniteBits a - , A.Integral a, A.FiniteBits a + , A.Integral a, A.FiniteBits a, A.FromIntegral Int a , Similar a, Show a ) => Gen a -> TestTree @@ -93,10 +93,10 @@ test_zipWith runN = , testProperty "(.&.)" $ test_band runN sh e , testProperty "(.|.)" $ test_bor runN sh e , testProperty "xor" $ test_xor runN sh e - , testProperty "shift" $ test_shift runN sh e + -- , testProperty "shift" $ test_shift runN sh e , testProperty "shiftL" $ test_shiftL runN sh e , testProperty "shiftR" $ test_shiftR runN sh e - , testProperty "rotate" $ test_rotate runN sh e + -- , testProperty "rotate" $ test_rotate runN sh e , testProperty "rotateL" $ test_rotateL runN sh e , testProperty "rotateR" $ test_rotateR runN sh e @@ -381,23 +381,23 @@ test_xor runN dim e = ys <- forAll (array sh2 e) let !go = runN (A.zipWith A.xor) in go xs ys ~~~ zipWithRef P.xor xs ys -test_shift - :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e) - => RunN - -> Gen sh - -> Gen e - -> Property -test_shift runN dim e = - property $ do - let s = P.finiteBitSize (undefined::e) - sh1 <- forAll dim - sh2 <- forAll dim - xs <- forAll (array sh1 e) - ys <- forAll (array sh2 (Gen.int (Range.linearFrom 0 (-s) s))) - let !go = runN (A.zipWith A.shift) in go xs ys ~~~ zipWithRef P.shift xs ys +-- test_shift +-- :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e) +-- => RunN +-- -> Gen sh +-- -> Gen e +-- -> Property +-- test_shift runN dim e = +-- property $ do +-- let s = P.finiteBitSize (undefined::e) +-- sh1 <- forAll dim +-- sh2 <- forAll dim +-- xs <- forAll (array sh1 e) +-- ys <- forAll (array sh2 (Gen.int (Range.linearFrom 0 (-s) s))) +-- let !go = runN (A.zipWith A.shift) in go xs ys ~~~ zipWithRef P.shift xs ys test_shiftL - :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e) + :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e, A.FromIntegral Int e) => RunN -> Gen sh -> Gen e @@ -409,10 +409,10 @@ test_shiftL runN dim e = sh2 <- forAll dim xs <- forAll (array sh1 e) ys <- forAll (array sh2 (Gen.int (Range.linear 0 s))) - let !go = runN (A.zipWith A.shiftL) in go xs ys ~~~ zipWithRef P.shiftL xs ys + let !go = runN (A.zipWith (\x -> A.shiftL x . A.fromIntegral)) in go xs ys ~~~ zipWithRef P.shiftL xs ys test_shiftR - :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e) + :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e, A.FromIntegral Int e) => RunN -> Gen sh -> Gen e @@ -424,25 +424,25 @@ test_shiftR runN dim e = sh2 <- forAll dim xs <- forAll (array sh1 e) ys <- forAll (array sh2 (Gen.int (Range.linear 0 s))) - let !go = runN (A.zipWith A.shiftR) in go xs ys ~~~ zipWithRef P.shiftR xs ys - -test_rotate - :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e) - => RunN - -> Gen sh - -> Gen e - -> Property -test_rotate runN dim e = - property $ do - let s = P.finiteBitSize (undefined::e) - sh1 <- forAll dim - sh2 <- forAll dim - xs <- forAll (array sh1 e) - ys <- forAll (array sh2 (Gen.int (Range.linearFrom 0 (-s) s))) - let !go = runN (A.zipWith A.rotate) in go xs ys ~~~ zipWithRef P.rotate xs ys + let !go = runN (A.zipWith (\x -> A.shiftR x . A.fromIntegral)) in go xs ys ~~~ zipWithRef P.shiftR xs ys + +-- test_rotate +-- :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e) +-- => RunN +-- -> Gen sh +-- -> Gen e +-- -> Property +-- test_rotate runN dim e = +-- property $ do +-- let s = P.finiteBitSize (undefined::e) +-- sh1 <- forAll dim +-- sh2 <- forAll dim +-- xs <- forAll (array sh1 e) +-- ys <- forAll (array sh2 (Gen.int (Range.linearFrom 0 (-s) s))) +-- let !go = runN (A.zipWith A.rotate) in go xs ys ~~~ zipWithRef P.rotate xs ys test_rotateL - :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e) + :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e, A.FromIntegral Int e) => RunN -> Gen sh -> Gen e @@ -454,10 +454,10 @@ test_rotateL runN dim e = sh2 <- forAll dim xs <- forAll (array sh1 e) ys <- forAll (array sh2 (Gen.int (Range.linear 0 s))) - let !go = runN (A.zipWith A.rotateL) in go xs ys ~~~ zipWithRef P.rotateL xs ys + let !go = runN (A.zipWith (\x -> A.rotateL x . A.fromIntegral)) in go xs ys ~~~ zipWithRef P.rotateL xs ys test_rotateR - :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e) + :: forall sh e. (Shape sh, Show sh, Similar e, Show e, P.Eq sh, P.FiniteBits e, A.FiniteBits e, A.FromIntegral Int e) => RunN -> Gen sh -> Gen e @@ -469,7 +469,7 @@ test_rotateR runN dim e = sh2 <- forAll dim xs <- forAll (array sh1 e) ys <- forAll (array sh2 (Gen.int (Range.linear 0 s))) - let !go = runN (A.zipWith A.rotateR) in go xs ys ~~~ zipWithRef P.rotateR xs ys + let !go = runN (A.zipWith (\x -> A.rotateR x . A.fromIntegral)) in go xs ys ~~~ zipWithRef P.rotateR xs ys test_lt :: (Shape sh, Show sh, Show e, P.Eq sh, P.Ord e, A.Ord e) diff --git a/src/Data/Array/Accelerate/Test/NoFib/Sharing.hs b/src/Data/Array/Accelerate/Test/NoFib/Sharing.hs index 7fada0655..722666522 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Sharing.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Sharing.hs @@ -85,13 +85,13 @@ mkArray n = use $ fromList (Z:.1) [n] test_blowup :: Int -> Acc (Array DIM1 Int) test_blowup 0 = (mkArray 0) -test_blowup n = A.map (\_ -> newArr ! (lift (Z:.(0::Int))) + - newArr ! (lift (Z:.(1::Int)))) (mkArray n) +test_blowup n = A.map (\_ -> newArr ! (Z :. 0) + + newArr ! (Z :. 1)) (mkArray n) where newArr = test_blowup (n-1) idx :: Int -> Exp DIM1 -idx i = lift (Z:.i) +idx i = constant (Z:.i) test_bfs :: Acc (Array DIM1 Int) test_bfs = A.map (\x -> (map2 ! (idx 1)) + (map1 ! (idx 2)) + x) arr diff --git a/src/Data/Array/Accelerate/Test/NoFib/Spectral/RadixSort.hs b/src/Data/Array/Accelerate/Test/NoFib/Spectral/RadixSort.hs index a65c4b3ef..2721e6820 100644 --- a/src/Data/Array/Accelerate/Test/NoFib/Spectral/RadixSort.hs +++ b/src/Data/Array/Accelerate/Test/NoFib/Spectral/RadixSort.hs @@ -24,7 +24,7 @@ module Data.Array.Accelerate.Test.NoFib.Spectral.RadixSort ( import Data.Function import Data.List ( sortBy ) -import Prelude as P +import Prelude as P hiding ( Maybe(..) ) import qualified Data.Bits as P import Data.Array.Accelerate as A @@ -139,21 +139,21 @@ instance Radix Word64 where radix = radixOfUnsigned radixOfSigned - :: forall e. (Radix e, A.Bounded e, A.Integral e, A.FromIntegral e Int) + :: forall e. (Radix e, A.Bounded e, A.Integral e, A.FromIntegral Int e, A.FromIntegral e Int) => Exp Int -> Exp e -> Exp Int radixOfSigned i e = i A.== (passes' - 1) ? (radix' (e `xor` minBound), radix' e) where - radix' x = A.fromIntegral $ (x `A.shiftR` i) .&. 1 + radix' x = A.fromIntegral $ (x `A.shiftR` A.fromIntegral i) .&. 1 passes' = constant (passes (undefined :: e)) radixOfUnsigned - :: (Radix e, A.Integral e, A.FromIntegral e Int) + :: (Radix e, A.Integral e, A.FromIntegral Int e, A.FromIntegral e Int) => Exp Int -> Exp e -> Exp Int -radixOfUnsigned i e = A.fromIntegral $ (e `A.shiftR` i) .&. 1 +radixOfUnsigned i e = A.fromIntegral $ (e `A.shiftR` A.fromIntegral i) .&. 1 -- A simple (parallel) radix sort implementation [1]. @@ -176,7 +176,7 @@ radixsortBy rdx arr = foldr1 (>->) (P.map radixPass [0..p-1]) arr iup = A.map (size v - 1 -) . prescanr (+) 0 $ flags index = A.zipWith deal flags (A.zip idown iup) in - permute const v (\ix -> Just_ (index1 (index!ix))) v + permute const v (\ix -> Just (index1 (index!ix))) v -- This is rather slow. Speeding up the reference implementation by using, say, diff --git a/src/Data/Array/Accelerate/Test/Similar.hs b/src/Data/Array/Accelerate/Test/Similar.hs index 1017408cc..1ca864506 100644 --- a/src/Data/Array/Accelerate/Test/Similar.hs +++ b/src/Data/Array/Accelerate/Test/Similar.hs @@ -64,33 +64,23 @@ instance Similar Int8 instance Similar Int16 instance Similar Int32 instance Similar Int64 +instance Similar Int128 instance Similar Word8 instance Similar Word16 instance Similar Word32 instance Similar Word64 +instance Similar Word128 instance Similar Char instance Similar Bool -instance Similar CShort -instance Similar CUShort -instance Similar CInt -instance Similar CUInt -instance Similar CLong -instance Similar CULong -instance Similar CLLong -instance Similar CULLong -instance Similar CChar -instance Similar CSChar -instance Similar CUChar instance Similar (Any Z) instance (Eq sh, Eq sz) => Similar (sh:.sz) instance (Eq sh) => Similar (Any (sh:.Int)) -instance Similar Half where (~=) = absRelTol 0.05 0.5 -instance Similar Float where (~=) = absRelTol 0.00005 0.005 -instance Similar Double where (~=) = absRelTol 0.00005 0.005 -instance Similar CFloat where (~=) = absRelTol 0.00005 0.005 -instance Similar CDouble where (~=) = absRelTol 0.00005 0.005 +instance Similar Float16 where (~=) = absRelTol 0.05 0.5 +instance Similar Float32 where (~=) = absRelTol 0.00005 0.005 +instance Similar Float64 where (~=) = absRelTol 0.000005 0.0005 +instance Similar Float128 where (~=) = absRelTol 0.0000005 0.0005 instance (Similar a, Similar b) => Similar (a, b) where (x1, x2) ~= (y1, y2) = x1 ~= y1 && x2 ~= y2 diff --git a/src/Data/Array/Accelerate/Trafo/Algebra.hs b/src/Data/Array/Accelerate/Trafo/Algebra.hs index 4c0fbd397..270862dcd 100644 --- a/src/Data/Array/Accelerate/Trafo/Algebra.hs +++ b/src/Data/Array/Accelerate/Trafo/Algebra.hs @@ -32,19 +32,19 @@ import Data.Array.Accelerate.AST.Var import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Pretty.Print ( primOperator, isInfix, opName ) import Data.Array.Accelerate.Trafo.Environment -import Data.Array.Accelerate.Type import qualified Data.Array.Accelerate.Debug.Internal.Stats as Stats -import Data.Bits +-- import Data.Bits import Data.Monoid +-- import Data.Primitive.Vec import Data.Text ( Text ) import Prettyprinter import Prettyprinter.Render.Text import Prelude hiding ( exp ) -import qualified Prelude as P +-- import qualified Prelude as P -import GHC.Float ( float2Double, double2Float ) +-- import GHC.Float ( float2Double, double2Float ) -- Propagate constant expressions, which are either constant valued expressions @@ -60,7 +60,6 @@ propagate env = cvtE cvtE :: OpenExp env aenv e -> Maybe e cvtE exp = case exp of Const _ c -> Just c - PrimConst c -> Just (evalPrimConst c) Evar (Var _ ix) | e <- prjExp ix env , Nothing <- matchOpenExp exp e -> cvtE e @@ -86,6 +85,8 @@ evalPrimApp env f x | otherwise = maybe (Any False, PrimApp f x) (Any True,) $ case f of + _ -> Nothing + {-- PrimAdd ty -> evalAdd ty x env PrimSub ty -> evalSub ty x env PrimMul ty -> evalMul ty x env @@ -148,6 +149,7 @@ evalPrimApp env f x PrimLNot -> evalLNot x env PrimFromIntegral ta tb -> evalFromIntegral ta tb x env PrimToFloating ta tb -> evalToFloating ta tb x env + --} -- Discriminate binary functions that commute, and if so return the operands in @@ -249,6 +251,7 @@ associates fun exp = case fun of -- Helper functions -- ---------------- +{-- type a :-> b = forall env aenv. OpenExp env aenv a -> Gamma env env aenv -> Maybe (OpenExp env aenv b) eval1 :: SingleType b -> (a -> b) -> a :-> b @@ -297,6 +300,7 @@ untup2 :: OpenExp env aenv (a, b) -> Maybe (OpenExp env aenv a, OpenExp env aenv untup2 exp | Pair a b <- exp = Just (a, b) | otherwise = Nothing +--} pprFun :: Text -> PrimFun f -> Text @@ -310,7 +314,7 @@ pprFun rule f then parens (opName op) else opName op - +{-- -- Methods of Num -- -------------- @@ -723,22 +727,5 @@ evalToFloating (FloatingNumType ta) tb x env | FloatingDict <- floatingDict ta , FloatingDict <- floatingDict tb = eval1 (NumSingleType $ FloatingNumType tb) realToFrac x env - - --- Scalar primitives --- ----------------- - -evalPrimConst :: PrimConst a -> a -evalPrimConst (PrimMinBound ty) = evalMinBound ty -evalPrimConst (PrimMaxBound ty) = evalMaxBound ty -evalPrimConst (PrimPi ty) = evalPi ty - -evalMinBound :: BoundedType a -> a -evalMinBound (IntegralBoundedType ty) | IntegralDict <- integralDict ty = minBound - -evalMaxBound :: BoundedType a -> a -evalMaxBound (IntegralBoundedType ty) | IntegralDict <- integralDict ty = maxBound - -evalPi :: FloatingType a -> a -evalPi ty | FloatingDict <- floatingDict ty = pi +--} diff --git a/src/Data/Array/Accelerate/Trafo/Delayed.hs b/src/Data/Array/Accelerate/Trafo/Delayed.hs index 045d205e0..48d7d44cb 100644 --- a/src/Data/Array/Accelerate/Trafo/Delayed.hs +++ b/src/Data/Array/Accelerate/Trafo/Delayed.hs @@ -30,6 +30,7 @@ import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Trafo.Substitution +import Data.Array.Accelerate.Type import Data.Array.Accelerate.Debug.Internal.Stats as Stats @@ -56,7 +57,7 @@ data DelayedOpenAcc aenv a where { reprD :: ArrayR (Array sh e) , extentD :: Exp aenv sh , indexD :: Fun aenv (sh -> e) - , linearIndexD :: Fun aenv (Int -> e) + , linearIndexD :: Fun aenv (INT -> e) } -> DelayedOpenAcc aenv (Array sh e) instance HasArraysR DelayedOpenAcc where diff --git a/src/Data/Array/Accelerate/Trafo/Fusion.hs b/src/Data/Array/Accelerate/Trafo/Fusion.hs index 81487be09..d4bcf3487 100644 --- a/src/Data/Array/Accelerate/Trafo/Fusion.hs +++ b/src/Data/Array/Accelerate/Trafo/Fusion.hs @@ -178,6 +178,7 @@ manifest config (OpenAcc pacc) = Apair a1 a2 -> Apair (manifest config a1) (manifest config a2) Anil -> Anil Atrace msg a1 a2 -> Atrace msg (manifest config a1) (manifest config a2) + Acoerce scale bR a -> Acoerce scale bR (manifest config a) Apply repr f a -> apply repr (cvtAF f) (manifest config a) Aforeign repr ff f a -> Aforeign repr ff (cvtAF f) (manifest config a) @@ -364,12 +365,13 @@ embedPreOpenAcc config matchAcc embedAcc elimAcc pacc -- duplication. SEE: [Sharing vs. Fusion] -- Alet lhs bnd body -> aletD embedAcc elimAcc lhs bnd body - Anil -> done $ Anil Acond p at ae -> acondD matchAcc embedAcc (cvtE p) at ae + Anil -> done $ Anil Apply aR f a -> done $ Apply aR (cvtAF f) (cvtA a) Awhile p f a -> done $ Awhile (cvtAF p) (cvtAF f) (cvtA a) Apair a1 a2 -> done $ Apair (cvtA a1) (cvtA a2) Atrace msg a1 a2 -> done $ Atrace msg (cvtA a1) (cvtA a2) + Acoerce scale bR a -> done $ Acoerce scale bR (cvtA a) Aforeign aR ff f a -> done $ Aforeign aR ff (cvtAF f) (cvtA a) -- Collect s -> collectD s @@ -1472,19 +1474,20 @@ aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed en Undef tR -> Undef tR Nil -> Nil Pair e1 e2 -> Pair (cvtE e1) (cvtE e2) - VecPack vR e -> VecPack vR (cvtE e) - VecUnpack vR e -> VecUnpack vR (cvtE e) + Extract vR iR v i -> Extract vR iR (cvtE v) (cvtE i) + Insert vR iR v i x -> Insert vR iR (cvtE v) (cvtE i) (cvtE x) + Shuffle vR iR x y i -> Shuffle vR iR (cvtE x) (cvtE y) (cvtE i) + Select eR m x y -> Select eR (cvtE m) (cvtE x) (cvtE y) IndexSlice x ix sh -> IndexSlice x (cvtE ix) (cvtE sh) IndexFull x ix sl -> IndexFull x (cvtE ix) (cvtE sl) ToIndex shR' sh ix -> ToIndex shR' (cvtE sh) (cvtE ix) FromIndex shR' sh i -> FromIndex shR' (cvtE sh) (cvtE i) - Case e rhs def -> Case (cvtE e) (over (mapped . _2) cvtE rhs) (fmap cvtE def) + Case eR e rhs def -> Case eR (cvtE e) (over (mapped . _2) cvtE rhs) (fmap cvtE def) Cond p t e -> Cond (cvtE p) (cvtE t) (cvtE e) - PrimConst c -> PrimConst c PrimApp g x -> PrimApp g (cvtE x) ShapeSize shR' sh -> ShapeSize shR' (cvtE sh) While p f x -> While (replaceF sh' f' avar p) (replaceF sh' f' avar f) (cvtE x) - Coerce t1 t2 e -> Coerce t1 t2 (cvtE e) + Bitcast t1 t2 e -> Bitcast t1 t2 (cvtE e) Shape a | Just Refl <- matchVar a avar -> Stats.substitution "replaceE/shape" sh' @@ -1500,10 +1503,10 @@ aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed en , Lam lhs (Body b) <- f' -> Stats.substitution "replaceE/!!" . cvtE $ Let lhs - (Let (LeftHandSideSingle scalarTypeInt) i + (Let (LeftHandSideSingle scalarType) i $ FromIndex shR (weakenE (weakenSucc' weakenId) sh') $ Evar - $ Var scalarTypeInt ZeroIdx) + $ Var scalarType ZeroIdx) b | otherwise -> LinearIndex a (cvtE i) @@ -1548,6 +1551,7 @@ aletD' embedAcc elimAcc (LeftHandSideSingle ArrayR{}) (Embed env1 cc1) (Embed en Acond p at ae -> Acond (cvtE p) (cvtA at) (cvtA ae) Anil -> Anil Atrace msg a b -> Atrace msg (cvtA a) (cvtA b) + Acoerce scale bR a -> Acoerce scale bR (cvtA a) Apair a1 a2 -> Apair (cvtA a1) (cvtA a2) Awhile p f a -> Awhile (cvtAF p) (cvtAF f) (cvtA a) Apply repr f a -> Apply repr (cvtAF f) (cvtA a) @@ -1673,23 +1677,23 @@ identity t | DeclareVars lhs _ value <- declareVars t = Lam lhs $ Body $ expVars $ value weakenId -toIndex :: ShapeR sh -> OpenExp env aenv sh -> OpenFun env aenv (sh -> Int) +toIndex :: ShapeR sh -> OpenExp env aenv sh -> OpenFun env aenv (sh -> INT) toIndex shR sh | DeclareVars lhs k value <- declareVars $ shapeType shR = Lam lhs $ Body $ ToIndex shR (weakenE k sh) $ expVars $ value weakenId -fromIndex :: ShapeR sh -> OpenExp env aenv sh -> OpenFun env aenv (Int -> sh) +fromIndex :: ShapeR sh -> OpenExp env aenv sh -> OpenFun env aenv (INT -> sh) fromIndex shR sh - = Lam (LeftHandSideSingle scalarTypeInt) + = Lam (LeftHandSideSingle scalarType) $ Body $ FromIndex shR (weakenE (weakenSucc' weakenId) sh) $ Evar - $ Var scalarTypeInt ZeroIdx + $ Var scalarType ZeroIdx intersect :: ShapeR sh -> OpenExp env aenv sh -> OpenExp env aenv sh -> OpenExp env aenv sh intersect = mkShapeBinary f where - f a b = PrimApp (PrimMin singleType) $ Pair a b + f a b = PrimApp (PrimMin scalarType) $ Pair a b -- union :: ShapeR sh -> OpenExp env aenv sh -> OpenExp env aenv sh -> OpenExp env aenv sh -- union = mkShapeBinary f @@ -1697,7 +1701,7 @@ intersect = mkShapeBinary f -- f a b = PrimApp (PrimMax singleType) $ Pair a b mkShapeBinary - :: (forall env'. OpenExp env' aenv Int -> OpenExp env' aenv Int -> OpenExp env' aenv Int) + :: (forall env'. OpenExp env' aenv INT -> OpenExp env' aenv INT -> OpenExp env' aenv INT) -> ShapeR sh -> OpenExp env aenv sh -> OpenExp env aenv sh @@ -1744,8 +1748,13 @@ indexArray v@(Var (ArrayR shR _) _) | DeclareVars lhs _ value <- declareVars $ shapeType shR = Lam lhs $ Body $ Index v $ expVars $ value weakenId -linearIndex :: ArrayVar aenv (Array sh e) -> Fun aenv (Int -> e) -linearIndex v = Lam (LeftHandSideSingle scalarTypeInt) $ Body $ LinearIndex v $ Evar $ Var scalarTypeInt ZeroIdx +linearIndex :: ArrayVar aenv (Array sh e) -> Fun aenv (INT -> e) +linearIndex v + = Lam (LeftHandSideSingle scalarType) + $ Body + $ LinearIndex v + $ Evar + $ Var scalarType ZeroIdx extractOpenAcc :: ExtractAcc OpenAcc diff --git a/src/Data/Array/Accelerate/Trafo/LetSplit.hs b/src/Data/Array/Accelerate/Trafo/LetSplit.hs index 9d3c1250e..efb3b94f5 100644 --- a/src/Data/Array/Accelerate/Trafo/LetSplit.hs +++ b/src/Data/Array/Accelerate/Trafo/LetSplit.hs @@ -39,6 +39,7 @@ convertPreOpenAcc = \case Apair a1 a2 -> Apair (convertAcc a1) (convertAcc a2) Anil -> Anil Atrace msg as bs -> Atrace msg (convertAcc as) (convertAcc bs) + Acoerce scale bR a -> Acoerce scale bR (convertAcc a) Apply repr f a -> Apply repr (convertAfun f) (convertAcc a) Aforeign repr asm f a -> Aforeign repr asm (convertAfun f) (convertAcc a) Acond e a1 a2 -> Acond e (convertAcc a1) (convertAcc a2) diff --git a/src/Data/Array/Accelerate/Trafo/Sharing.hs b/src/Data/Array/Accelerate/Trafo/Sharing.hs index d122bfc30..e8bca9635 100644 --- a/src/Data/Array/Accelerate/Trafo/Sharing.hs +++ b/src/Data/Array/Accelerate/Trafo/Sharing.hs @@ -337,6 +337,7 @@ convertSharingAcc config alyt aenv (ScopedAcc lams (AccSharing _ preAcc)) Aprj ix a -> let AST.OpenAcc a' = cvtAprj ix a in a' Atrace msg acc1 acc2 -> AST.Atrace msg (cvtA acc1) (cvtA acc2) + Acoerce scale bR acc -> AST.Acoerce scale bR (cvtA acc) Use repr array -> AST.Use repr array Unit tp e -> AST.Unit tp (cvtE e) Generate repr@(ArrayR shr _) sh f @@ -758,21 +759,22 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp Prj idx e -> cvtPrj idx (cvt e) Nil -> AST.Nil Pair e1 e2 -> AST.Pair (cvt e1) (cvt e2) - VecPack vec e -> AST.VecPack vec (cvt e) - VecUnpack vec e -> AST.VecUnpack vec (cvt e) + Extract vR iR v i -> AST.Extract vR iR (cvt v) (cvt i) + Insert vR iR v i x -> AST.Insert vR iR (cvt v) (cvt i) (cvt x) + Shuffle eR iR x y i -> AST.Shuffle eR iR (cvt x) (cvt y) (cvt i) + Select eR m x y -> AST.Select eR (cvt m) (cvt x) (cvt y) ToIndex shr sh ix -> AST.ToIndex shr (cvt sh) (cvt ix) FromIndex shr sh e -> AST.FromIndex shr (cvt sh) (cvt e) Case e rhs -> cvtCase (cvt e) (over (mapped . _2) cvt rhs) Cond e1 e2 e3 -> AST.Cond (cvt e1) (cvt e2) (cvt e3) While tp p it i -> AST.While (cvtFun1 tp p) (cvtFun1 tp it) (cvt i) - PrimConst c -> AST.PrimConst c PrimApp f e -> cvtPrimFun f (cvt e) Index _ a e -> AST.Index (cvtAvar a) (cvt e) LinearIndex _ a i -> AST.LinearIndex (cvtAvar a) (cvt i) Shape _ a -> AST.Shape (cvtAvar a) ShapeSize shr e -> AST.ShapeSize shr (cvt e) Foreign repr ff f e -> AST.Foreign repr ff (convertSmartFun config (typeR e) f) (cvt e) - Coerce t1 t2 e -> AST.Coerce t1 t2 (cvt e) + Bitcast t1 t2 e -> AST.Bitcast t1 t2 (cvt e) cvtPrj :: forall a b c env1 aenv1. PairIdx (a, b) c -> AST.OpenExp env1 aenv1 (a, b) -> AST.OpenExp env1 aenv1 c cvtPrj PairIdxLeft (AST.Pair a _) = a @@ -806,7 +808,7 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp AST.Let lhs bnd body -> AST.Let lhs bnd (cvtPrimFun f body) x -> AST.PrimApp f x - -- Convert the flat list of equations into nested case statement + -- Convert the flat list of equations into nested case statements -- directly on the tag variables. -- cvtCase :: HasCallStack => AST.OpenExp env' aenv' a -> [(TagR a, AST.OpenExp env' aenv' b)] -> AST.OpenExp env' aenv' b @@ -817,26 +819,43 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp = AST.Let lhs s $ nested (expVars (value weakenId)) (over (mapped . _2) (weakenE (weakenWithLHS lhs)) es) where nested :: HasCallStack => AST.OpenExp env' aenv' a -> [(TagR a, AST.OpenExp env' aenv' b)] -> AST.OpenExp env' aenv' b + nested _ [] = internalError "empty case" nested _ [(_,r)] = r - nested s rs = - let groups = groupBy (eqT `on` fst) rs - tags = map (firstT . fst . head) groups - e = prjT (fst (head rs)) s - rhs = map (nested s . map (over _1 ignore)) groups - in - AST.Case e (zip tags rhs) Nothing + nested s rs@(r:_) + | Exists tag <- tagT (fst r) + = let groups = groupBy (eqT `on` fst) rs + rhs = map (nested s . map (over _1 ignore)) groups + e = prjT tag (fst r) s + tags = map (firstT tag . fst . head) groups + in + AST.Case tag e (zip tags rhs) Nothing + + tagT :: TagR a -> Exists TagType + tagT = fromJust . go + where + go ::TagR a -> Maybe (Exists TagType) + go (TagRcon t _ _) = Just (Exists t) + go (TagRenum t _) = Just (Exists t) + go (TagRpair a b) + | Just t <- go a = Just t + | otherwise = go b + go _ = Nothing -- Extract the variable representing this particular tag from the -- scrutinee. This is safe because we let-bind the argument first. - prjT :: TagR a -> AST.OpenExp env' aenv' a -> AST.OpenExp env' aenv' TAG - prjT = fromJust $$ go + prjT :: forall a t env' aenv'. TagType t -> TagR a -> AST.OpenExp env' aenv' a -> AST.OpenExp env' aenv' t + prjT t = fromJust $$ go where - go :: TagR a -> AST.OpenExp env' aenv' a -> Maybe (AST.OpenExp env' aenv' TAG) - go TagRtag{} (AST.Pair l _) = Just l - go (TagRpair ta tb) (AST.Pair l r) = - case go ta l of - Just t -> Just t - Nothing -> go tb r + go :: TagR s -> AST.OpenExp env' aenv' s -> Maybe (AST.OpenExp env' aenv' t) + go (TagRcon s _ _) (AST.Pair l _) + | Just Refl <- matchTagType s t + = Just l + go (TagRenum s _) x + | Just Refl <- matchTagType s t + = Just x + go (TagRpair ta tb) (AST.Pair l r) + | Just x <- go ta l = Just x + | otherwise = go tb r go _ _ = Nothing -- Equality up to the first constructor tag encountered @@ -844,26 +863,37 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp eqT a b = snd $ go a b where go :: TagR a -> TagR a -> (Any, Bool) - go TagRunit TagRunit = no True - go TagRsingle{} TagRsingle{} = no True - go TagRundef{} TagRundef{} = no True - go (TagRtag v1 _) (TagRtag v2 _) = yes (v1 == v2) - go (TagRpair a1 b1) (TagRpair a2 b2) = + go TagRunit TagRunit = no True + go TagRsingle{} TagRsingle{} = no True + go TagRundef{} TagRundef{} = no True + go (TagRenum t1 v1) (TagRenum t2 v2) + | Just Refl <- matchTagType t1 t2 + , TagDict <- tagDict t1 + = yes (v1 == v2) + go (TagRcon t1 v1 _) (TagRcon t2 v2 _) + | Just Refl <- matchTagType t1 t2 + , TagDict <- tagDict t1 + = yes (v1 == v2) + go (TagRpair a1 b1) (TagRpair a2 b2) = let (Any r, s) = go a1 a2 in case r of True -> yes s False -> go b1 b2 go _ _ = no False - firstT :: TagR a -> TAG - firstT = fromJust . go + firstT :: forall a t. TagType t -> TagR a -> t + firstT t = fromJust . go where - go :: TagR a -> Maybe TAG - go (TagRtag v _) = Just v - go (TagRpair a b) = - case go a of - Just t -> Just t - Nothing -> go b + go :: TagR s -> Maybe t + go (TagRcon s v _) + | Just Refl <- matchTagType s t + = Just v + go (TagRenum s v) + | Just Refl <- matchTagType s t + = Just v + go (TagRpair a b) + | Just v <- go a = Just v + | otherwise = go b go _ = Nothing -- Replace the first constructor tag encountered with a regular @@ -875,7 +905,16 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp go TagRunit = no $ TagRunit go (TagRsingle t) = no $ TagRsingle t go (TagRundef t) = no $ TagRundef t - go (TagRtag _ a) = yes $ TagRpair (TagRundef scalarType) a + go (TagRenum t _) = yes $ + case t of + TagBit -> TagRundef scalarType + TagWord8 -> TagRundef scalarType + TagWord16 -> TagRundef scalarType + go (TagRcon t _ a) = yes $ + case t of + TagBit -> TagRpair (TagRundef scalarType) a + TagWord8 -> TagRpair (TagRundef scalarType) a + TagWord16 -> TagRpair (TagRundef scalarType) a go (TagRpair a1 a2) = let (Any r, a1') = go a1 in case r of @@ -1240,7 +1279,7 @@ instance HasTypeR UnscopedExp where typeR (UnscopedExp _ exp) = Smart.typeR exp -- Specifies a scalar expression AST with sharing. For expressions rooted in functions the list --- holds a sorted environment corresponding to the variables bound in the immediate surounding +-- holds a sorted environment corresponding to the variables bound in the immediate surrounding -- lambdas. data ScopedExp t = ScopedExp [StableSharingExp] (SharingExp ScopedAcc ScopedExp t) @@ -1519,6 +1558,9 @@ makeOccMapSharingAcc config accOccMap = traverseAcc (a', h1) <- traverseAcc lvl acc1 (b', h2) <- traverseAcc lvl acc2 return (Atrace msg a' b', h1 `max` h2 + 1) + Acoerce scale bR acc -> do + (a', h) <- traverseAcc lvl acc + return (Acoerce scale bR a', h + 1) Use repr arr -> return (Use repr arr, 1) Unit tp e -> do (e', h) <- traverseExp lvl e @@ -1837,37 +1879,38 @@ makeOccMapSharingExp config accOccMap expOccMap = travE return (UnscopedExp [] (ExpSharing (StableNameHeight sn height) exp), height) reconstruct $ case pexp of - Tag tp i -> return (Tag tp i, 0) -- height is 0! - Const tp c -> return (Const tp c, 1) - Undef tp -> return (Undef tp, 1) - Nil -> return (Nil, 1) - Pair e1 e2 -> travE2 Pair e1 e2 - Prj i e -> travE1 (Prj i) e - VecPack vec e -> travE1 (VecPack vec) e - VecUnpack vec e -> travE1 (VecUnpack vec) e - ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix - FromIndex shr sh e -> travE2 (FromIndex shr) sh e - Match t e -> travE1 (Match t) e - Case e rhs -> do - (e', h1) <- travE lvl e - (rhs', h2) <- unzip <$> sequence [ travE1 (t,) c | (t,c) <- rhs ] - return (Case e' rhs', h1 `max` maximum h2 + 1) - Cond e1 e2 e3 -> travE3 Cond e1 e2 e3 - While t p iter init -> do - (p' , h1) <- traverseFun1 lvl t p - (iter', h2) <- traverseFun1 lvl t iter - (init', h3) <- travE lvl init - return (While t p' iter' init', h1 `max` h2 `max` h3 + 1) - PrimConst c -> return (PrimConst c, 1) - PrimApp p e -> travE1 (PrimApp p) e - Index tp a e -> travAE (Index tp) a e - LinearIndex tp a i -> travAE (LinearIndex tp) a i - Shape shr a -> travA (Shape shr) a - ShapeSize shr e -> travE1 (ShapeSize shr) e - Foreign tp ff f e -> do - (e', h) <- travE lvl e - return (Foreign tp ff f e', h+1) - Coerce t1 t2 e -> travE1 (Coerce t1 t2) e + Tag tp i -> return (Tag tp i, 0) -- height is 0! + Const tp c -> return (Const tp c, 1) + Undef tp -> return (Undef tp, 1) + Nil -> return (Nil, 1) + Pair e1 e2 -> travE2 Pair e1 e2 + Prj i e -> travE1 (Prj i) e + Extract vR iR v i -> travE2 (Extract vR iR) v i + Insert vR iR v i x -> travE3 (Insert vR iR) v i x + Shuffle eR iR x y i -> travE3 (Shuffle eR iR) x y i + Select eR m x y -> travE3 (Select eR) m x y + ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix + FromIndex shr sh e -> travE2 (FromIndex shr) sh e + Match t e -> travE1 (Match t) e + Case e rhs -> do + (e', h1) <- travE lvl e + (rhs', h2) <- unzip <$> sequence [ travE1 (t,) c | (t,c) <- rhs ] + return (Case e' rhs', h1 `max` maximum h2 + 1) + Cond e1 e2 e3 -> travE3 Cond e1 e2 e3 + While t p iter init -> do + (p' , h1) <- traverseFun1 lvl t p + (iter', h2) <- traverseFun1 lvl t iter + (init', h3) <- travE lvl init + return (While t p' iter' init', h1 `max` h2 `max` h3 + 1) + PrimApp p e -> travE1 (PrimApp p) e + Index tp a e -> travAE (Index tp) a e + LinearIndex tp a i -> travAE (LinearIndex tp) a i + Shape shr a -> travA (Shape shr) a + ShapeSize shr e -> travE1 (ShapeSize shr) e + Foreign tp ff f e -> do + (e', h) <- travE lvl e + return (Foreign tp ff f e', h+1) + Bitcast t1 t2 e -> travE1 (Bitcast t1 t2) e where traverseAcc :: HasCallStack => Level -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int) @@ -2382,7 +2425,11 @@ determineScopesSharingAcc config accOccMap = scopesAcc (a1', accCount1) = scopesAcc a1 (a2', accCount2) = scopesAcc a2 in - reconstruct (Atrace msg a1' a2') (accCount1 +++ accCount2) + reconstruct (Atrace msg a1' a2') (accCount1 +++ accCount2) + Acoerce scale bR acc -> let + (acc', accCount) = scopesAcc acc + in + reconstruct (Acoerce scale bR acc') accCount Use repr arr -> reconstruct (Use repr arr) noNodeCounts Unit tp e -> let (e', accCount) = scopesExp e @@ -2614,12 +2661,14 @@ determineScopesSharingAcc config accOccMap = scopesAcc :: HasCallStack => (SmartAcc a1 -> UnscopedAcc a2) -> (SmartAcc a1 -> ScopedAcc a2, NodeCounts) - scopesAfun1 f = (const (ScopedAcc ssa body'), (counts', graph)) + scopesAfun1 f + | not (null env) = internalError "unexpected unbound variables" + | otherwise = (const (ScopedAcc ssa body'), (counts', graph)) where - body@(UnscopedAcc fvs _) = f undefined - (ScopedAcc [] body', (counts,graph)) = scopesAcc body - (freeCounts, counts') = partition isBoundHere counts - ssa = buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- freeCounts] + body@(UnscopedAcc fvs _) = f undefined + (ScopedAcc env body', (counts,graph)) = scopesAcc body + (freeCounts, counts') = partition isBoundHere counts + ssa = buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- freeCounts] isBoundHere (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag _ i))) _) = i `elem` fvs isBoundHere _ = False @@ -2692,14 +2741,19 @@ determineScopesExp -> RootExp t -> (ScopedExp t, NodeCounts) -- Root (closed) expression plus Acc node counts determineScopesExp config accOccMap (RootExp expOccMap exp@(UnscopedExp fvs _)) - = let - (ScopedExp [] expWithScopes, (nodeCounts,graph)) = determineScopesSharingExp config accOccMap expOccMap exp - (expCounts, accCounts) = partition isExpNodeCount nodeCounts + | not (null env) + = internalError "unexpected unbound variables" + -- + | otherwise + = ( ScopedExp (buildInitialEnvExp fvs [se | ExpNodeCount se _ <- expCounts]) expWithScopes + , cleanCounts (accCounts,graph) + ) + where + (ScopedExp env expWithScopes, (nodeCounts,graph)) = determineScopesSharingExp config accOccMap expOccMap exp + (expCounts, accCounts) = partition isExpNodeCount nodeCounts - isExpNodeCount ExpNodeCount{} = True - isExpNodeCount _ = False - in - (ScopedExp (buildInitialEnvExp fvs [se | ExpNodeCount se _ <- expCounts]) expWithScopes, cleanCounts (accCounts,graph)) + isExpNodeCount ExpNodeCount{} = True + isExpNodeCount _ = False determineScopesSharingExp @@ -2721,12 +2775,18 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp :: HasCallStack => (SmartExp a -> UnscopedExp b) -> (SmartExp a -> ScopedExp b, NodeCounts) - scopesFun1 f = tracePure (bformat ("LAMBDA " % list formatStableSharingExp) ssa) (bformat (list formatNodeCount) counts) (const (ScopedExp ssa body'), (counts',graph)) + scopesFun1 f + | not (null env) + = internalError "unexpected unbound variables" + -- + | otherwise + = tracePure (bformat ("LAMBDA " % list formatStableSharingExp) ssa) (bformat (list formatNodeCount) counts) + $ (const (ScopedExp ssa body'), (counts',graph)) where - body@(UnscopedExp fvs _) = f undefined - (ScopedExp [] body', (counts, graph)) = scopesExp body - (freeCounts, counts') = partition isBoundHere counts - ssa = buildInitialEnvExp fvs [se | ExpNodeCount se _ <- freeCounts] + body@(UnscopedExp fvs _) = f undefined + (ScopedExp env body', (counts, graph)) = scopesExp body + (freeCounts, counts') = partition isBoundHere counts + ssa = buildInitialEnvExp fvs [se | ExpNodeCount se _ <- freeCounts] isBoundHere (ExpNodeCount (StableSharingExp _ (ExpSharing _ (Tag _ i))) _) = i `elem` fvs isBoundHere _ = False @@ -2749,8 +2809,10 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp Pair e1 e2 -> travE2 Pair e1 e2 Nil -> reconstruct Nil noNodeCounts Prj i e -> travE1 (Prj i) e - VecPack vec e -> travE1 (VecPack vec) e - VecUnpack vec e -> travE1 (VecUnpack vec) e + Extract vR iR v i -> travE2 (Extract vR iR) v i + Insert vR iR v i x -> travE3 (Insert vR iR) v i x + Shuffle eR iR x y i -> travE3 (Shuffle eR iR) x y i + Select eR m x y -> travE3 (Select eR) m x y ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix FromIndex shr sh e -> travE2 (FromIndex shr) sh e Match t e -> travE1 (Match t) e @@ -2762,14 +2824,13 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp (it', accCount2) = scopesFun1 it (i' , accCount3) = scopesExp i in reconstruct (While tp p' it' i') (accCount1 +++ accCount2 +++ accCount3) - PrimConst c -> reconstruct (PrimConst c) noNodeCounts PrimApp p e -> travE1 (PrimApp p) e Index tp a e -> travAE (Index tp) a e LinearIndex tp a e -> travAE (LinearIndex tp) a e Shape shr a -> travA (Shape shr) a ShapeSize shr e -> travE1 (ShapeSize shr) e Foreign tp ff f e -> travE1 (Foreign tp ff f) e - Coerce t1 t2 e -> travE1 (Coerce t1 t2) e + Bitcast t1 t2 e -> travE1 (Bitcast t1 t2) e where travE1 :: HasCallStack => (ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t) @@ -3146,6 +3207,18 @@ recoverSharingSeq config seq --} +-- Utilities +-- --------- + +data TagDict t where + TagDict :: Eq t => TagDict t + +tagDict :: TagType t -> TagDict t +tagDict TagBit = TagDict +tagDict TagWord8 = TagDict +tagDict TagWord16 = TagDict + + -- Debugging -- --------- diff --git a/src/Data/Array/Accelerate/Trafo/Shrink.hs b/src/Data/Array/Accelerate/Trafo/Shrink.hs index 574747865..f76013a86 100644 --- a/src/Data/Array/Accelerate/Trafo/Shrink.hs +++ b/src/Data/Array/Accelerate/Trafo/Shrink.hs @@ -1,11 +1,12 @@ -{-# LANGUAGE TupleSections #-} {-# LANGUAGE CPP #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternGuards #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} @@ -208,18 +209,19 @@ strengthenShrunkLHS -> env1 :?> env1' -> env2 :?> env2' strengthenShrunkLHS (LeftHandSideWildcard _) (LeftHandSideWildcard _) k = k -strengthenShrunkLHS (LeftHandSideSingle _) (LeftHandSideSingle _) k = \ix -> case ix of +strengthenShrunkLHS (LeftHandSideSingle _) (LeftHandSideSingle _) k = \case ZeroIdx -> Just ZeroIdx SuccIdx ix' -> SuccIdx <$> k ix' -strengthenShrunkLHS (LeftHandSidePair lA hA) (LeftHandSidePair lB hB) k = strengthenShrunkLHS hA hB $ strengthenShrunkLHS lA lB k -strengthenShrunkLHS (LeftHandSideSingle _) (LeftHandSideWildcard _) k = \ix -> case ix of +strengthenShrunkLHS (LeftHandSidePair lA hA) (LeftHandSidePair lB hB) k = + strengthenShrunkLHS hA hB $ strengthenShrunkLHS lA lB k +strengthenShrunkLHS (LeftHandSideSingle _) (LeftHandSideWildcard _) k = \case ZeroIdx -> Nothing SuccIdx ix' -> k ix' -strengthenShrunkLHS (LeftHandSidePair l h) (LeftHandSideWildcard t) k = strengthenShrunkLHS h (LeftHandSideWildcard t2) $ strengthenShrunkLHS l (LeftHandSideWildcard t1) k - where - TupRpair t1 t2 = t -strengthenShrunkLHS (LeftHandSideWildcard _) _ _ = internalError "Second LHS defines more variables" -strengthenShrunkLHS _ _ _ = internalError "Mismatch LHS single with LHS pair" +strengthenShrunkLHS (LeftHandSidePair l h) (LeftHandSideWildcard (TupRpair t1 t2)) k + = strengthenShrunkLHS h (LeftHandSideWildcard t2) + $ strengthenShrunkLHS l (LeftHandSideWildcard t1) k +strengthenShrunkLHS (LeftHandSideWildcard _) _ _ = internalError "Second LHS defines more variables" +strengthenShrunkLHS _ _ _ = internalError "Mismatch LHS single with LHS pair" -- Shrinking @@ -240,14 +242,13 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE lIMIT = 1 cheap :: OpenExp env aenv t -> Bool - cheap (Evar _) = True - cheap (Pair e1 e2) = cheap e1 && cheap e2 - cheap Nil = True - cheap Const{} = True - cheap PrimConst{} = True - cheap Undef{} = True - cheap (Coerce _ _ e) = cheap e - cheap _ = False + cheap (Evar _) = True + cheap (Pair e1 e2) = cheap e1 && cheap e2 + cheap Nil = True + cheap Const{} = True + cheap Undef{} = True + cheap (Bitcast _ _ e) = cheap e + cheap _ = False shrinkE :: HasCallStack => OpenExp env aenv t -> (Any, OpenExp env aenv t) shrinkE exp = case exp of @@ -291,23 +292,24 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE Undef t -> pure (Undef t) Nil -> pure Nil Pair x y -> Pair <$> shrinkE x <*> shrinkE y - VecPack vec e -> VecPack vec <$> shrinkE e - VecUnpack vec e -> VecUnpack vec <$> shrinkE e + Extract vR iR v i -> Extract vR iR <$> shrinkE v <*> shrinkE i + Insert vR iR v i x -> Insert vR iR <$> shrinkE v <*> shrinkE i <*> shrinkE x + Shuffle eR iR x y i -> Shuffle eR iR <$> shrinkE x <*> shrinkE y <*> shrinkE i + Select eR m x y -> Select eR <$> shrinkE m <*> shrinkE x <*> shrinkE y IndexSlice x ix sh -> IndexSlice x <$> shrinkE ix <*> shrinkE sh IndexFull x ix sl -> IndexFull x <$> shrinkE ix <*> shrinkE sl ToIndex shr sh ix -> ToIndex shr <$> shrinkE sh <*> shrinkE ix FromIndex shr sh i -> FromIndex shr <$> shrinkE sh <*> shrinkE i - Case e rhs def -> Case <$> shrinkE e <*> sequenceA [ (t,) <$> shrinkE c | (t,c) <- rhs ] <*> shrinkMaybeE def + Case eR e rhs def -> Case eR <$> shrinkE e <*> sequenceA [ (t,) <$> shrinkE c | (t,c) <- rhs ] <*> shrinkMaybeE def Cond p t e -> Cond <$> shrinkE p <*> shrinkE t <*> shrinkE e While p f x -> While <$> shrinkF p <*> shrinkF f <*> shrinkE x - PrimConst c -> pure (PrimConst c) PrimApp f x -> PrimApp f <$> shrinkE x Index a sh -> Index a <$> shrinkE sh LinearIndex a i -> LinearIndex a <$> shrinkE i Shape a -> pure (Shape a) ShapeSize shr sh -> ShapeSize shr <$> shrinkE sh Foreign repr ff f e -> Foreign repr ff <$> shrinkF f <*> shrinkE e - Coerce t1 t2 e -> Coerce t1 t2 <$> shrinkE e + Bitcast t1 t2 e -> Bitcast t1 t2 <$> shrinkE e shrinkF :: HasCallStack => OpenFun env aenv t -> (Any, OpenFun env aenv t) shrinkF = first Any . shrinkFun @@ -447,7 +449,6 @@ shrinkPreAcc shrinkAcc reduceAcc = Stats.substitution "shrinkA" shrinkA FromIndex sh i -> FromIndex (shrinkE sh) (shrinkE i) Cond p t e -> Cond (shrinkE p) (shrinkE t) (shrinkE e) While p f x -> While (shrinkF p) (shrinkF f) (shrinkE x) - PrimConst c -> PrimConst c PrimApp f x -> PrimApp f (shrinkE x) Index a sh -> Index (shrinkAcc a) (shrinkE sh) LinearIndex a i -> LinearIndex (shrinkAcc a) (shrinkE i) @@ -456,7 +457,7 @@ shrinkPreAcc shrinkAcc reduceAcc = Stats.substitution "shrinkA" shrinkA Intersect sh sz -> Intersect (shrinkE sh) (shrinkE sz) Union sh sz -> Union (shrinkE sh) (shrinkE sz) Foreign ff f e -> Foreign ff (shrinkF f) (shrinkE e) - Coerce e -> Coerce (shrinkE e) + Bitcast e -> Bitcast (shrinkE e) shrinkF :: OpenFun env aenv' f -> OpenFun env aenv' f shrinkF (Lam f) = Lam (shrinkF f) @@ -492,23 +493,24 @@ usesOfExp range = countE Undef _ -> Finite 0 Nil -> Finite 0 Pair e1 e2 -> countE e1 <> countE e2 - VecPack _ e -> countE e - VecUnpack _ e -> countE e + Extract _ _ v i -> countE v <> countE i + Insert _ _ v i x -> countE v <> countE i <> countE x + Shuffle _ _ x y i -> countE x <> countE y <> countE i + Select _ m x y -> countE m <> countE x <> countE y IndexSlice _ ix sh -> countE ix <> countE sh IndexFull _ ix sl -> countE ix <> countE sl FromIndex _ sh i -> countE sh <> countE i ToIndex _ sh e -> countE sh <> countE e - Case e rhs def -> countE e <> mconcat [ countE c | (_,c) <- rhs ] <> maybe (Finite 0) countE def + Case _ e rhs def -> countE e <> mconcat [ countE c | (_,c) <- rhs ] <> maybe (Finite 0) countE def Cond p t e -> countE p <> countE t <> countE e While p f x -> countE x <> loopCount (usesOfFun range p) <> loopCount (usesOfFun range f) - PrimConst _ -> Finite 0 PrimApp _ x -> countE x Index _ sh -> countE sh LinearIndex _ i -> countE i Shape _ -> Finite 0 ShapeSize _ sh -> countE sh Foreign _ _ _ e -> countE e - Coerce _ _ e -> countE e + Bitcast _ _ e -> countE e usesOfFun :: VarsRange env -> OpenFun env aenv f -> Count usesOfFun range (Lam lhs f) = usesOfFun (weakenVarsRange lhs range) f @@ -544,6 +546,7 @@ usesOfPreAcc withShape countAcc idx = count Apair a1 a2 -> countA a1 + countA a2 Anil -> 0 Atrace _ a1 a2 -> countA a1 + countA a2 + Acoerce _ _ a -> countA a Apply _ f a -> countAF f idx + countA a Aforeign _ _ _ a -> countA a Acond p t e -> countE p + countA t + countA e @@ -579,16 +582,17 @@ usesOfPreAcc withShape countAcc idx = count Undef _ -> 0 Nil -> 0 Pair x y -> countE x + countE y - VecPack _ e -> countE e - VecUnpack _ e -> countE e + Extract _ _ v i -> countE v + countE i + Insert _ _ v i x -> countE v + countE i + countE x + Shuffle _ _ x y i -> countE x + countE y + countE i + Select _ m x y -> countE m + countE x + countE y IndexSlice _ ix sh -> countE ix + countE sh IndexFull _ ix sl -> countE ix + countE sl ToIndex _ sh ix -> countE sh + countE ix FromIndex _ sh i -> countE sh + countE i - Case e rhs def -> countE e + sum [ countE c | (_,c) <- rhs ] + maybe 0 countE def + Case _ e rhs def -> countE e + sum [ countE c | (_,c) <- rhs ] + maybe 0 countE def Cond p t e -> countE p + countE t + countE e While p f x -> countF p + countF f + countE x - PrimConst _ -> 0 PrimApp _ x -> countE x Index a sh -> countAvar a + countE sh LinearIndex a i -> countAvar a + countE i @@ -597,7 +601,7 @@ usesOfPreAcc withShape countAcc idx = count | withShape -> countAvar a | otherwise -> 0 Foreign _ _ _ e -> countE e - Coerce _ _ e -> countE e + Bitcast _ _ e -> countE e countME :: Maybe (OpenExp env aenv e) -> Int countME = maybe 0 countE diff --git a/src/Data/Array/Accelerate/Trafo/Simplify.hs b/src/Data/Array/Accelerate/Trafo/Simplify.hs index 6b0e8aa2c..7bf19b44a 100644 --- a/src/Data/Array/Accelerate/Trafo/Simplify.hs +++ b/src/Data/Array/Accelerate/Trafo/Simplify.hs @@ -91,6 +91,7 @@ simplifyPreOpenAcc = \case Apair a1 a2 -> Apair (simplifyOpenAcc a1) (simplifyOpenAcc a2) Anil -> Anil Atrace msg as bs -> Atrace msg (simplifyOpenAcc as) (simplifyOpenAcc bs) + Acoerce scale bR a -> Acoerce scale bR (simplifyOpenAcc a) Apply repr f a -> Apply repr (simplifyOpenAfun f) (simplifyOpenAcc a) Aforeign repr asm f a -> Aforeign repr asm (simplifyOpenAfun f) (simplifyOpenAcc a) Acond e a1 a2 -> Acond (simplifyExp e) (simplifyOpenAcc a1) (simplifyOpenAcc a2) @@ -276,15 +277,16 @@ simplifyOpenExp env = first getAny . cvtE Undef tp -> pure $ Undef tp Nil -> pure Nil Pair e1 e2 -> Pair <$> cvtE e1 <*> cvtE e2 - VecPack vec e -> VecPack vec <$> cvtE e - VecUnpack vec e -> VecUnpack vec <$> cvtE e + Extract vR iR v i -> Extract vR iR <$> cvtE v <*> cvtE i + Insert vR iR v i x -> Insert vR iR <$> cvtE v <*> cvtE i <*> cvtE x + Shuffle eR iR x y i -> Shuffle eR iR <$> cvtE x <*> cvtE y <*> cvtE i + Select eR m x y -> Select eR <$> cvtE m <*> cvtE x <*> cvtE y IndexSlice x ix sh -> IndexSlice x <$> cvtE ix <*> cvtE sh IndexFull x ix sl -> IndexFull x <$> cvtE ix <*> cvtE sl ToIndex shr sh ix -> toIndex shr (cvtE sh) (cvtE ix) FromIndex shr sh ix -> fromIndex shr (cvtE sh) (cvtE ix) - Case e rhs def -> caseof (cvtE e) (sequenceA [ (t,) <$> cvtE c | (t,c) <- rhs ]) (cvtMaybeE def) + Case eR e rhs def -> caseof eR (cvtE e) (sequenceA [ (t,) <$> cvtE c | (t,c) <- rhs ]) (cvtMaybeE def) Cond p t e -> cond (cvtE p) (cvtE t) (cvtE e) - PrimConst c -> pure $ PrimConst c PrimApp f x -> (u<>v, fx) where (u, x') = cvtE x @@ -295,7 +297,7 @@ simplifyOpenExp env = first getAny . cvtE ShapeSize shr sh -> shapeSize shr (cvtE sh) Foreign tp ff f e -> Foreign tp ff <$> first Any (simplifyOpenFun EmptyExp f) <*> cvtE e While p f x -> While <$> cvtF env p <*> cvtF env f <*> cvtE x - Coerce t1 t2 e -> Coerce t1 t2 <$> cvtE e + Bitcast t1 t2 e -> Bitcast t1 t2 <$> cvtE e cvtE' :: Gamma env' env' aenv -> OpenExp env' aenv e' -> (Any, OpenExp env' aenv e') cvtE' env' = first Any . simplifyOpenExp env' @@ -333,12 +335,14 @@ simplifyOpenExp env = first getAny . cvtE | Just Refl <- matchOpenExp t' e' = Stats.knownBranch "redundant" (yes e') | otherwise = Cond <$> p <*> t <*> e - caseof :: (Any, OpenExp env aenv TAG) - -> (Any, [(TAG, OpenExp env aenv b)]) + caseof :: TagType tag + -> (Any, OpenExp env aenv tag) + -> (Any, [(tag, OpenExp env aenv b)]) -> (Any, Maybe (OpenExp env aenv b)) -> (Any, OpenExp env aenv b) - caseof x@(_,x') xs@(_,xs') md@(_,md') + caseof tagR x@(_,x') xs@(_,xs') md@(_,md') | Const _ t <- x' + , TagDict <- tagDict tagR = Stats.caseElim "known" (yes (fromJust $ lookup t xs')) | Just d <- md' , [] <- xs' @@ -346,26 +350,27 @@ simplifyOpenExp env = first getAny . cvtE | Just d <- md' , [(_,(_,u))] <- us , Just Refl <- matchOpenExp d u - = Stats.caseDefault "merge" $ yes (Case x' (map snd vs) (Just u)) + = Stats.caseDefault "merge" $ yes (Case tagR x' (map snd vs) (Just u)) | Nothing <- md' , [] <- vs , [(_,(_,u))] <- us = Stats.caseElim "overlap" (yes u) | Nothing <- md' , [(_,(_,u))] <- us - = Stats.caseDefault "introduction" $ yes (Case x' (map snd vs) (Just u)) + = Stats.caseDefault "introduction" $ yes (Case tagR x' (map snd vs) (Just u)) | otherwise - = Case <$> x <*> xs <*> md + = Case tagR <$> x <*> xs <*> md where (us,vs) = partition (\(n,_) -> n > 1) $ Map.elems - . Map.fromListWith merge + . Map.fromListWith (merge tagR) $ [ (hashOpenExp e, (1,(t, e))) | (t,e) <- xs' ] - merge :: (Int, (TAG, OpenExp env aenv b)) -> (Int, (TAG, OpenExp env aenv b)) -> (Int, (TAG, OpenExp env aenv b)) - merge (n,(_,a)) (m,(_,b)) + merge :: TagType tag -> (Int, (tag, OpenExp env aenv b)) -> (Int, (tag, OpenExp env aenv b)) -> (Int, (tag, OpenExp env aenv b)) + merge t (n,(_,a)) (m,(_,b)) + | TagDict <- tagDict t = internalCheck "hashOpenExp/collision" (maybe False (const True) (matchOpenExp a b)) - $ (n+m, (0xff, a)) + $ (n+m, (maxBound, a)) -- Shape manipulations -- @@ -375,24 +380,24 @@ simplifyOpenExp env = first getAny . cvtE shape a = pure $ Shape a - shapeSize :: ShapeR sh -> (Any, OpenExp env aenv sh) -> (Any, OpenExp env aenv Int) + shapeSize :: ShapeR sh -> (Any, OpenExp env aenv sh) -> (Any, OpenExp env aenv INT) shapeSize shr (_, sh) | Just c <- extractConstTuple sh - = Stats.ruleFired "shapeSize/const" $ yes (Const scalarTypeInt (product (shapeToList shr c))) + = Stats.ruleFired "shapeSize/const" $ yes (Const (scalarType @INT) (product (shapeToList shr c))) shapeSize shr sh = ShapeSize shr <$> sh toIndex :: ShapeR sh -> (Any, OpenExp env aenv sh) -> (Any, OpenExp env aenv sh) - -> (Any, OpenExp env aenv Int) + -> (Any, OpenExp env aenv INT) toIndex _ (_,sh) (_,FromIndex _ sh' ix) | Just Refl <- matchOpenExp sh sh' = Stats.ruleFired "toIndex/fromIndex" $ yes ix toIndex shr sh ix = ToIndex shr <$> sh <*> ix fromIndex :: ShapeR sh -> (Any, OpenExp env aenv sh) - -> (Any, OpenExp env aenv Int) + -> (Any, OpenExp env aenv INT) -> (Any, OpenExp env aenv sh) fromIndex _ (_,sh) (_,ToIndex _ sh' ix) | Just Refl <- matchOpenExp sh sh' = Stats.ruleFired "fromIndex/toIndex" $ yes ix @@ -404,6 +409,7 @@ simplifyOpenExp env = first getAny . cvtE yes :: x -> (Any, x) yes x = (Any True, x) + extractConstTuple :: OpenExp env aenv t -> Maybe t extractConstTuple Nil = Just () extractConstTuple (Pair e1 e2) = (,) <$> extractConstTuple e1 <*> extractConstTuple e2 @@ -555,37 +561,22 @@ summariseOpenExp = (terms +~ 1) . goE travA :: acc aenv a -> Stats travA _ = zero & vars +~ 1 -- assume an array index, else we should have failed elsewhere - travC :: PrimConst c -> Stats - travC (PrimMinBound t) = travBoundedType t & terms +~ 1 - travC (PrimMaxBound t) = travBoundedType t & terms +~ 1 - travC (PrimPi t) = travFloatingType t & terms +~ 1 - travIntegralType :: IntegralType t -> Stats travIntegralType _ = zero & types +~ 1 travFloatingType :: FloatingType t -> Stats travFloatingType _ = zero & types +~ 1 + travBitType :: BitType t -> Stats + travBitType _ = zero & types +~ 1 + travNumType :: NumType t -> Stats travNumType (IntegralNumType t) = travIntegralType t & types +~ 1 travNumType (FloatingNumType t) = travFloatingType t & types +~ 1 - travBoundedType :: BoundedType t -> Stats - travBoundedType (IntegralBoundedType t) = travIntegralType t & types +~ 1 - - -- travScalarType :: ScalarType t -> Stats - -- travScalarType (SingleScalarType t) = travSingleType t & types +~ 1 - -- travScalarType (VectorScalarType t) = travVectorType t & types +~ 1 - - travSingleType :: SingleType t -> Stats - travSingleType (NumSingleType t) = travNumType t & types +~ 1 - - -- travVectorType :: VectorType t -> Stats - -- travVectorType (Vector2Type t) = travSingleType t & types +~ 1 - -- travVectorType (Vector3Type t) = travSingleType t & types +~ 1 - -- travVectorType (Vector4Type t) = travSingleType t & types +~ 1 - -- travVectorType (Vector8Type t) = travSingleType t & types +~ 1 - -- travVectorType (Vector16Type t) = travSingleType t & types +~ 1 + travScalarType :: ScalarType t -> Stats + travScalarType (NumScalarType t) = travNumType t & types +~ 1 + travScalarType (BitScalarType _) = zero & types +~ 1 -- The scrutinee has already been counted goE :: OpenExp env aenv t -> Stats @@ -597,23 +588,24 @@ summariseOpenExp = (terms +~ 1) . goE Const{} -> zero Undef _ -> zero Nil -> zero & terms +~ 1 - Pair e1 e2 -> travE e1 +++ travE e2 & terms +~ 1 - VecPack _ e -> travE e - VecUnpack _ e -> travE e + Pair e1 e2 -> travE e1 +++ travE e2 + Extract _ _ v i -> travE v +++ travE i + Insert _ _ v i x -> travE v +++ travE i +++ travE x + Shuffle _ _ x y i -> travE x +++ travE y +++ travE i + Select _ m x y -> travE m +++ travE x +++ travE y IndexSlice _ slix sh -> travE slix +++ travE sh & terms +~ 1 -- +1 for sliceIndex IndexFull _ slix sl -> travE slix +++ travE sl & terms +~ 1 -- +1 for sliceIndex ToIndex _ sh ix -> travE sh +++ travE ix FromIndex _ sh ix -> travE sh +++ travE ix - Case e rhs def -> travE e +++ mconcat [ travE c | (_,c) <- rhs ] +++ maybe zero travE def + Case _ e rhs def -> travE e +++ mconcat [ travE c | (_,c) <- rhs ] +++ maybe zero travE def Cond p t e -> travE p +++ travE t +++ travE e While p f x -> travF p +++ travF f +++ travE x - PrimConst c -> travC c Index a ix -> travA a +++ travE ix LinearIndex a ix -> travA a +++ travE ix Shape a -> travA a ShapeSize _ sh -> travE sh PrimApp f x -> travPrimFun f +++ travE x - Coerce _ _ e -> travE e + Bitcast _ _ e -> travE e travPrimFun :: PrimFun f -> Stats travPrimFun = (ops +~ 1) . goF @@ -621,66 +613,91 @@ summariseOpenExp = (terms +~ 1) . goE goF :: PrimFun f -> Stats goF fun = case fun of - PrimAdd t -> travNumType t - PrimSub t -> travNumType t - PrimMul t -> travNumType t - PrimNeg t -> travNumType t - PrimAbs t -> travNumType t - PrimSig t -> travNumType t - PrimQuot t -> travIntegralType t - PrimRem t -> travIntegralType t - PrimQuotRem t -> travIntegralType t - PrimIDiv t -> travIntegralType t - PrimMod t -> travIntegralType t - PrimDivMod t -> travIntegralType t - PrimBAnd t -> travIntegralType t - PrimBOr t -> travIntegralType t - PrimBXor t -> travIntegralType t - PrimBNot t -> travIntegralType t - PrimBShiftL t -> travIntegralType t - PrimBShiftR t -> travIntegralType t - PrimBRotateL t -> travIntegralType t - PrimBRotateR t -> travIntegralType t - PrimPopCount t -> travIntegralType t - PrimCountLeadingZeros t -> travIntegralType t + PrimAdd t -> travNumType t + PrimSub t -> travNumType t + PrimMul t -> travNumType t + PrimNeg t -> travNumType t + PrimAbs t -> travNumType t + PrimSig t -> travNumType t + PrimVAdd t -> travNumType t + PrimVMul t -> travNumType t + PrimQuot t -> travIntegralType t + PrimRem t -> travIntegralType t + PrimQuotRem t -> travIntegralType t + PrimIDiv t -> travIntegralType t + PrimMod t -> travIntegralType t + PrimDivMod t -> travIntegralType t + PrimBAnd t -> travIntegralType t + PrimBOr t -> travIntegralType t + PrimBXor t -> travIntegralType t + PrimBNot t -> travIntegralType t + PrimBShiftL t -> travIntegralType t + PrimBShiftR t -> travIntegralType t + PrimBRotateL t -> travIntegralType t + PrimBRotateR t -> travIntegralType t + PrimPopCount t -> travIntegralType t + PrimCountLeadingZeros t -> travIntegralType t PrimCountTrailingZeros t -> travIntegralType t - PrimFDiv t -> travFloatingType t - PrimRecip t -> travFloatingType t - PrimSin t -> travFloatingType t - PrimCos t -> travFloatingType t - PrimTan t -> travFloatingType t - PrimAsin t -> travFloatingType t - PrimAcos t -> travFloatingType t - PrimAtan t -> travFloatingType t - PrimSinh t -> travFloatingType t - PrimCosh t -> travFloatingType t - PrimTanh t -> travFloatingType t - PrimAsinh t -> travFloatingType t - PrimAcosh t -> travFloatingType t - PrimAtanh t -> travFloatingType t - PrimExpFloating t -> travFloatingType t - PrimSqrt t -> travFloatingType t - PrimLog t -> travFloatingType t - PrimFPow t -> travFloatingType t - PrimLogBase t -> travFloatingType t - PrimTruncate f i -> travFloatingType f +++ travIntegralType i - PrimRound f i -> travFloatingType f +++ travIntegralType i - PrimFloor f i -> travFloatingType f +++ travIntegralType i - PrimCeiling f i -> travFloatingType f +++ travIntegralType i - PrimIsNaN t -> travFloatingType t - PrimIsInfinite t -> travFloatingType t - PrimAtan2 t -> travFloatingType t - PrimLt t -> travSingleType t - PrimGt t -> travSingleType t - PrimLtEq t -> travSingleType t - PrimGtEq t -> travSingleType t - PrimEq t -> travSingleType t - PrimNEq t -> travSingleType t - PrimMax t -> travSingleType t - PrimMin t -> travSingleType t - PrimLAnd -> zero - PrimLOr -> zero - PrimLNot -> zero - PrimFromIntegral i n -> travIntegralType i +++ travNumType n - PrimToFloating n f -> travNumType n +++ travFloatingType f + PrimBReverse t -> travIntegralType t + PrimBSwap t -> travIntegralType t + PrimVBAnd t -> travIntegralType t + PrimVBOr t -> travIntegralType t + PrimVBXor t -> travIntegralType t + PrimFDiv t -> travFloatingType t + PrimRecip t -> travFloatingType t + PrimSin t -> travFloatingType t + PrimCos t -> travFloatingType t + PrimTan t -> travFloatingType t + PrimAsin t -> travFloatingType t + PrimAcos t -> travFloatingType t + PrimAtan t -> travFloatingType t + PrimSinh t -> travFloatingType t + PrimCosh t -> travFloatingType t + PrimTanh t -> travFloatingType t + PrimAsinh t -> travFloatingType t + PrimAcosh t -> travFloatingType t + PrimAtanh t -> travFloatingType t + PrimExpFloating t -> travFloatingType t + PrimSqrt t -> travFloatingType t + PrimLog t -> travFloatingType t + PrimFPow t -> travFloatingType t + PrimLogBase t -> travFloatingType t + PrimTruncate f i -> travFloatingType f +++ travIntegralType i + PrimRound f i -> travFloatingType f +++ travIntegralType i + PrimFloor f i -> travFloatingType f +++ travIntegralType i + PrimCeiling f i -> travFloatingType f +++ travIntegralType i + PrimIsNaN t -> travFloatingType t + PrimIsInfinite t -> travFloatingType t + PrimAtan2 t -> travFloatingType t + PrimLt t -> travScalarType t + PrimGt t -> travScalarType t + PrimLtEq t -> travScalarType t + PrimGtEq t -> travScalarType t + PrimEq t -> travScalarType t + PrimNEq t -> travScalarType t + PrimMin t -> travScalarType t + PrimMax t -> travScalarType t + PrimVMin t -> travScalarType t + PrimVMax t -> travScalarType t + PrimLAnd t -> travBitType t + PrimLOr t -> travBitType t + PrimLNot t -> travBitType t + PrimVLAnd t -> travBitType t + PrimVLOr t -> travBitType t + PrimFromIntegral i n -> travIntegralType i +++ travNumType n + PrimToFloating n f -> travNumType n +++ travFloatingType f + PrimToBool i b -> travIntegralType i +++ travBitType b + PrimFromBool b i -> travBitType b +++ travIntegralType i + + +-- Utilities +-- --------- + +data TagDict t where + TagDict :: (Eq t, Bounded t) => TagDict t + +tagDict :: TagType t -> TagDict t +tagDict TagBit = TagDict +tagDict TagWord8 = TagDict +tagDict TagWord16 = TagDict diff --git a/src/Data/Array/Accelerate/Trafo/Substitution.hs b/src/Data/Array/Accelerate/Trafo/Substitution.hs index e1aa1176b..8cb0a3512 100644 --- a/src/Data/Array/Accelerate/Trafo/Substitution.hs +++ b/src/Data/Array/Accelerate/Trafo/Substitution.hs @@ -149,29 +149,30 @@ inlineVars lhsBound expr bound substitute k1 k2 vars topExp = case topExp of Let lhs e1 e2 | Exists lhs' <- rebuildLHS lhs - -> Let lhs' <$> travE e1 <*> substitute (strengthenAfter lhs lhs' k1) (weakenWithLHS lhs' .> k2) (weakenWithLHS lhs `weakenVars` vars) e2 - Evar (Var t ix) -> Evar . Var t <$> k1 ix - Foreign tp asm f e1 -> Foreign tp asm f <$> travE e1 - Pair e1 e2 -> Pair <$> travE e1 <*> travE e2 - Nil -> Just Nil - VecPack vec e1 -> VecPack vec <$> travE e1 - VecUnpack vec e1 -> VecUnpack vec <$> travE e1 - IndexSlice si e1 e2 -> IndexSlice si <$> travE e1 <*> travE e2 - IndexFull si e1 e2 -> IndexFull si <$> travE e1 <*> travE e2 - ToIndex shr e1 e2 -> ToIndex shr <$> travE e1 <*> travE e2 - FromIndex shr e1 e2 -> FromIndex shr <$> travE e1 <*> travE e2 - Case e1 rhs def -> Case <$> travE e1 <*> mapM (\(t,c) -> (t,) <$> travE c) rhs <*> travMaybeE def - Cond e1 e2 e3 -> Cond <$> travE e1 <*> travE e2 <*> travE e3 - While f1 f2 e1 -> While <$> travF f1 <*> travF f2 <*> travE e1 - Const t c -> Just $ Const t c - PrimConst c -> Just $ PrimConst c - PrimApp p e1 -> PrimApp p <$> travE e1 - Index a e1 -> Index a <$> travE e1 - LinearIndex a e1 -> LinearIndex a <$> travE e1 - Shape a -> Just $ Shape a - ShapeSize shr e1 -> ShapeSize shr <$> travE e1 - Undef t -> Just $ Undef t - Coerce t1 t2 e1 -> Coerce t1 t2 <$> travE e1 + -> Let lhs' <$> travE e1 <*> substitute (strengthenAfter lhs lhs' k1) (weakenWithLHS lhs' .> k2) (weakenWithLHS lhs `weakenVars` vars) e2 + Evar (Var t ix) -> Evar . Var t <$> k1 ix + Foreign tp asm f e1 -> Foreign tp asm f <$> travE e1 + Pair e1 e2 -> Pair <$> travE e1 <*> travE e2 + Nil -> Just Nil + Extract vR iR v i -> Extract vR iR <$> travE v <*> travE i + Insert vR iR v i x -> Insert vR iR <$> travE v <*> travE i <*> travE x + Shuffle vR iR x y m -> Shuffle vR iR <$> travE x <*> travE y <*> travE m + Select eR m x y -> Select eR <$> travE m <*> travE x <*> travE y + IndexSlice si e1 e2 -> IndexSlice si <$> travE e1 <*> travE e2 + IndexFull si e1 e2 -> IndexFull si <$> travE e1 <*> travE e2 + ToIndex shr e1 e2 -> ToIndex shr <$> travE e1 <*> travE e2 + FromIndex shr e1 e2 -> FromIndex shr <$> travE e1 <*> travE e2 + Case eR e1 rhs def -> Case eR <$> travE e1 <*> mapM (\(t,c) -> (t,) <$> travE c) rhs <*> travMaybeE def + Cond e1 e2 e3 -> Cond <$> travE e1 <*> travE e2 <*> travE e3 + While f1 f2 e1 -> While <$> travF f1 <*> travF f2 <*> travE e1 + Const t c -> Just $ Const t c + PrimApp p e1 -> PrimApp p <$> travE e1 + Index a e1 -> Index a <$> travE e1 + LinearIndex a e1 -> LinearIndex a <$> travE e1 + Shape a -> Just $ Shape a + ShapeSize shr e1 -> ShapeSize shr <$> travE e1 + Undef t -> Just $ Undef t + Bitcast t1 t2 e1 -> Bitcast t1 t2 <$> travE e1 where travE :: OpenExp env1 aenv s -> Maybe (OpenExp env2 aenv s) @@ -546,31 +547,32 @@ rebuildOpenExp -> f (OpenExp env' aenv' t) rebuildOpenExp v av@(ReindexAvar reindex) exp = case exp of - Const t c -> pure $ Const t c - PrimConst c -> pure $ PrimConst c - Undef t -> pure $ Undef t - Evar var -> expOut <$> v var + Const t c -> pure $ Const t c + Undef t -> pure $ Undef t + Evar var -> expOut <$> v var Let lhs a b | Exists lhs' <- rebuildLHS lhs - -> Let lhs' <$> rebuildOpenExp v av a <*> rebuildOpenExp (shiftE' lhs lhs' v) av b - Pair e1 e2 -> Pair <$> rebuildOpenExp v av e1 <*> rebuildOpenExp v av e2 - Nil -> pure Nil - VecPack vec e -> VecPack vec <$> rebuildOpenExp v av e - VecUnpack vec e -> VecUnpack vec <$> rebuildOpenExp v av e - IndexSlice x ix sh -> IndexSlice x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sh - IndexFull x ix sl -> IndexFull x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sl - ToIndex shr sh ix -> ToIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix - FromIndex shr sh ix -> FromIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix - Case e rhs def -> Case <$> rebuildOpenExp v av e <*> sequenceA [ (t,) <$> rebuildOpenExp v av c | (t,c) <- rhs ] <*> rebuildMaybeExp v av def - Cond p t e -> Cond <$> rebuildOpenExp v av p <*> rebuildOpenExp v av t <*> rebuildOpenExp v av e - While p f x -> While <$> rebuildFun v av p <*> rebuildFun v av f <*> rebuildOpenExp v av x - PrimApp f x -> PrimApp f <$> rebuildOpenExp v av x - Index a sh -> Index <$> reindex a <*> rebuildOpenExp v av sh - LinearIndex a i -> LinearIndex <$> reindex a <*> rebuildOpenExp v av i - Shape a -> Shape <$> reindex a - ShapeSize shr sh -> ShapeSize shr <$> rebuildOpenExp v av sh - Foreign tp ff f e -> Foreign tp ff f <$> rebuildOpenExp v av e - Coerce t1 t2 e -> Coerce t1 t2 <$> rebuildOpenExp v av e + -> Let lhs' <$> rebuildOpenExp v av a <*> rebuildOpenExp (shiftE' lhs lhs' v) av b + Pair e1 e2 -> Pair <$> rebuildOpenExp v av e1 <*> rebuildOpenExp v av e2 + Nil -> pure Nil + Extract vR iR u i -> Extract vR iR <$> rebuildOpenExp v av u <*> rebuildOpenExp v av i + Insert vR iR u i x -> Insert vR iR <$> rebuildOpenExp v av u <*> rebuildOpenExp v av i <*> rebuildOpenExp v av x + Shuffle vR iR x y m -> Shuffle vR iR <$> rebuildOpenExp v av x <*> rebuildOpenExp v av y <*> rebuildOpenExp v av m + Select eR m x y -> Select eR <$> rebuildOpenExp v av m <*> rebuildOpenExp v av x <*> rebuildOpenExp v av y + IndexSlice x ix sh -> IndexSlice x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sh + IndexFull x ix sl -> IndexFull x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sl + ToIndex shr sh ix -> ToIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix + FromIndex shr sh ix -> FromIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix + Case eR e rhs def -> Case eR <$> rebuildOpenExp v av e <*> sequenceA [ (t,) <$> rebuildOpenExp v av c | (t,c) <- rhs ] <*> rebuildMaybeExp v av def + Cond p t e -> Cond <$> rebuildOpenExp v av p <*> rebuildOpenExp v av t <*> rebuildOpenExp v av e + While p f x -> While <$> rebuildFun v av p <*> rebuildFun v av f <*> rebuildOpenExp v av x + PrimApp f x -> PrimApp f <$> rebuildOpenExp v av x + Index a sh -> Index <$> reindex a <*> rebuildOpenExp v av sh + LinearIndex a i -> LinearIndex <$> reindex a <*> rebuildOpenExp v av i + Shape a -> Shape <$> reindex a + ShapeSize shr sh -> ShapeSize shr <$> rebuildOpenExp v av sh + Foreign tp ff f e -> Foreign tp ff f <$> rebuildOpenExp v av e + Bitcast t1 t2 e -> Bitcast t1 t2 <$> rebuildOpenExp v av e {-# INLINEABLE rebuildFun #-} rebuildFun @@ -681,6 +683,7 @@ rebuildPreOpenAcc k av acc = Apair as bs -> Apair <$> k av as <*> k av bs Anil -> pure Anil Atrace msg as bs -> Atrace msg <$> k av as <*> k av bs + Acoerce s bR a -> Acoerce s bR <$> k av a Apply repr f a -> Apply repr <$> rebuildAfun k av f <*> k av a Acond p t e -> Acond <$> rebuildOpenExp (pure . IE) av' p <*> k av t <*> k av e Awhile p f a -> Awhile <$> rebuildAfun k av p <*> rebuildAfun k av f <*> k av a diff --git a/src/Data/Array/Accelerate/Trafo/Var.hs b/src/Data/Array/Accelerate/Trafo/Var.hs index 76cb2b741..c6bb254f9 100644 --- a/src/Data/Array/Accelerate/Trafo/Var.hs +++ b/src/Data/Array/Accelerate/Trafo/Var.hs @@ -3,6 +3,7 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} +{-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Trafo.Var -- Copyright : [2012..2020] The Accelerate Team @@ -33,7 +34,7 @@ data DeclareVars s t aenv where declareVars :: TupR s t -> DeclareVars s t env declareVars TupRunit - = DeclareVars LeftHandSideUnit weakenId $ const $ TupRunit + = DeclareVars LeftHandSideUnit weakenId $ const TupRunit declareVars (TupRsingle s) = DeclareVars (LeftHandSideSingle s) (weakenSucc weakenId) $ \k -> TupRsingle $ Var s $ k >:> ZeroIdx declareVars (TupRpair r1 r2) diff --git a/src/Data/Array/Accelerate/Type.hs b/src/Data/Array/Accelerate/Type.hs index 94e891cc1..3330a4f0e 100644 --- a/src/Data/Array/Accelerate/Type.hs +++ b/src/Data/Array/Accelerate/Type.hs @@ -27,494 +27,452 @@ -- Primitive scalar types supported by Accelerate -- -- Integral types: --- * Int -- * Int8 -- * Int16 -- * Int32 -- * Int64 --- * Word +-- * Int128 -- * Word8 -- * Word16 -- * Word32 -- * Word64 +-- * Word128 -- --- Floating types: +-- Floating types (IEEE): -- * Half -- * Float -- * Double +-- * Float128 -- --- SIMD vector types of the above: --- * Vec2 --- * Vec3 --- * Vec4 --- * Vec8 --- * Vec16 +-- A single bit -- --- Note that 'Int' has the same bit width as in plain Haskell computations. --- 'Float' and 'Double' represent IEEE single and double precision floating --- point numbers, respectively. +-- and SIMD vector types of all of the above -- module Data.Array.Accelerate.Type ( - Half(..), Float, Double, - module Data.Int, - module Data.Word, - module Foreign.C.Types, + Bit(..), Half(..), Float, Double, Float128(..), + module Data.Int, Int128(..), + module Data.Word, Word128(..), module Data.Array.Accelerate.Type, ) where import Data.Array.Accelerate.Orphans () -- Prim Half + +import Data.Primitive.Bit import Data.Primitive.Vec +import Data.Numeric.Float128 import Data.Bits import Data.Int -import Data.Primitive.Types +import Data.Kind import Data.Type.Equality +import Data.WideWord.Int128 +import Data.WideWord.Word128 import Data.Word -import Foreign.C.Types -import Foreign.Storable ( Storable ) import Formatting -import Language.Haskell.TH.Extra +import Language.Haskell.TH.Extra hiding ( Type ) import Numeric.Half import Text.Printf +import Unsafe.Coerce import GHC.Prim -import GHC.TypeLits +import GHC.TypeNats +type Float16 = Half +type Float32 = Float +type Float64 = Double -- Scalar types -- ------------ --- Reified dictionaries --- -data SingleDict a where - SingleDict :: ( Eq a, Ord a, Show a, Storable a, Prim a ) - => SingleDict a - -data IntegralDict a where - IntegralDict :: ( Eq a, Ord a, Show a - , Bounded a, Bits a, FiniteBits a, Integral a, Num a, Real a, Storable a ) - => IntegralDict a - -data FloatingDict a where - FloatingDict :: ( Eq a, Ord a, Show a - , Floating a, Fractional a, Num a, Real a, RealFrac a, RealFloat a, Storable a ) - => FloatingDict a - - --- Scalar type representation +-- | Scalar element types are values that can be stored in machine +-- registers: ground types (int32, float64, etc.) and SIMD vectors of these -- +data ScalarType a where + NumScalarType :: NumType a -> ScalarType a + BitScalarType :: BitType a -> ScalarType a + -- Void? --- | Integral types supported in array computations. --- -data IntegralType a where - TypeInt :: IntegralType Int - TypeInt8 :: IntegralType Int8 - TypeInt16 :: IntegralType Int16 - TypeInt32 :: IntegralType Int32 - TypeInt64 :: IntegralType Int64 - TypeWord :: IntegralType Word - TypeWord8 :: IntegralType Word8 - TypeWord16 :: IntegralType Word16 - TypeWord32 :: IntegralType Word32 - TypeWord64 :: IntegralType Word64 - --- | Floating-point types supported in array computations. --- -data FloatingType a where - TypeHalf :: FloatingType Half - TypeFloat :: FloatingType Float - TypeDouble :: FloatingType Double +data BitType a where + TypeBit :: BitType Bit + TypeMask :: KnownNat n => Proxy# n -> BitType (Vec n Bit) --- | Numeric element types implement Num & Real --- data NumType a where IntegralNumType :: IntegralType a -> NumType a FloatingNumType :: FloatingType a -> NumType a --- | Bounded element types implement Bounded --- -data BoundedType a where - IntegralBoundedType :: IntegralType a -> BoundedType a +data IntegralType a where + SingleIntegralType :: SingleIntegralType a -> IntegralType a + VectorIntegralType :: KnownNat n => Proxy# n -> SingleIntegralType a -> IntegralType (Vec n a) --- | All scalar element types implement Eq & Ord +-- Note: [Arbitrary width integers] -- -data ScalarType a where - SingleScalarType :: SingleType a -> ScalarType a - VectorScalarType :: VectorType (Vec n a) -> ScalarType (Vec n a) - -data SingleType a where - NumSingleType :: NumType a -> SingleType a +-- We support arbitrary width signed and unsigned integers, but for almost all +-- cases you should use the type synonyms generated below. +-- +data SingleIntegralType :: Type -> Type where + TypeInt :: {-# UNPACK #-} !Int -> SingleIntegralType a + TypeWord :: {-# UNPACK #-} !Int -> SingleIntegralType a -data VectorType a where - VectorType :: KnownNat n => {-# UNPACK #-} !Int -> SingleType a -> VectorType (Vec n a) +data FloatingType a where + SingleFloatingType :: SingleFloatingType a -> FloatingType a + VectorFloatingType :: KnownNat n => Proxy# n -> SingleFloatingType a -> FloatingType (Vec n a) + +data SingleFloatingType a where + -- TypeFloat8 :: SingleFloatingType Float8 + -- TypeBFloat16 :: SingleFloatingType BFloat16 + TypeFloat16 :: SingleFloatingType Float16 + TypeFloat32 :: SingleFloatingType Float32 + TypeFloat64 :: SingleFloatingType Float64 + TypeFloat128 :: SingleFloatingType Float128 + + +typeIntBits :: SingleIntegralType a -> Maybe Int +typeIntBits (TypeInt x) = Just x +typeIntBits _ = Nothing + +typeWordBits :: SingleIntegralType a -> Maybe Int +typeWordBits (TypeWord x) = Just x +typeWordBits _ = Nothing + +-- Generate pattern synonyms for fixed sized signed and unsigned integers. In +-- practice this is what we'll use most of the time, but occasionally we need to +-- convert via an arbitrary width integer (e.g. coercing between a BitMask and +-- an integral value). +-- +-- SEE: [Arbitrary width integers] +-- +runQ $ do + let + integralTypes :: [Integer] + integralTypes = [8,16,32,64,128] + + mkIntegral :: String -> Integer -> Q [Dec] + mkIntegral name bits = do + let t = conT $ mkName $ printf "%s%d" name bits + e = varE $ mkName $ printf "type%sBits" name + c = mkName $ printf "Type%s%d" name bits + -- + a <- newName "a" + sequence [ patSynSigD c (forallT [plainInvisTV a specifiedSpec] (return [TupleT 0]) [t| () => ($(varT a) ~ $t) => SingleIntegralType $(varT a) |]) + , patSynD c (prefixPatSyn []) (explBidir [clause [] (normalB [| $(conE (mkName (printf "Type%s" name))) $(litE (integerL bits)) |]) [] ]) + [p| (\x -> ($e x, unsafeCoerce Refl) -> (Just $(litP (integerL bits)), Refl :: $(varT a) :~: $t)) |] + ] + -- + cs <- pragCompleteD [ mkName (printf "Type%s%d" name bits) | name <- ["Int", "Word" :: String], bits <- integralTypes ] Nothing + ss <- mapM (mkIntegral "Int") integralTypes + us <- mapM (mkIntegral "Word") integralTypes + return (cs : concat ss ++ concat us) instance Show (IntegralType a) where - show TypeInt = "Int" - show TypeInt8 = "Int8" - show TypeInt16 = "Int16" - show TypeInt32 = "Int32" - show TypeInt64 = "Int64" - show TypeWord = "Word" - show TypeWord8 = "Word8" - show TypeWord16 = "Word16" - show TypeWord32 = "Word32" - show TypeWord64 = "Word64" + show (SingleIntegralType t) = show t + show (VectorIntegralType n t) = printf "<%d x %s>" (natVal' n) (show t) instance Show (FloatingType a) where - show TypeHalf = "Half" - show TypeFloat = "Float" - show TypeDouble = "Double" + show (SingleFloatingType t) = show t + show (VectorFloatingType n t) = printf "<%d x %s>" (natVal' n) (show t) -instance Show (NumType a) where - show (IntegralNumType ty) = show ty - show (FloatingNumType ty) = show ty +instance Show (SingleIntegralType a) where + show (TypeInt n) = printf "Int%d" n + show (TypeWord n) = printf "Word%d" n -instance Show (BoundedType a) where - show (IntegralBoundedType ty) = show ty +instance Show (SingleFloatingType a) where + show TypeFloat16 = "Float16" + show TypeFloat32 = "Float32" + show TypeFloat64 = "Float64" + show TypeFloat128 = "Float128" -instance Show (SingleType a) where - show (NumSingleType ty) = show ty +instance Show (NumType a) where + show (IntegralNumType t) = show t + show (FloatingNumType t) = show t -instance Show (VectorType a) where - show (VectorType n ty) = printf "<%d x %s>" n (show ty) +instance Show (BitType t) where + show (TypeBit) = "Bit" + show (TypeMask n) = printf "<%d x Bit>" (natVal' n) instance Show (ScalarType a) where - show (SingleScalarType ty) = show ty - show (VectorScalarType ty) = show ty + show (NumScalarType t) = show t + show (BitScalarType t) = show t formatIntegralType :: Format r (IntegralType a -> r) formatIntegralType = later $ \case - TypeInt -> "Int" - TypeInt8 -> "Int8" - TypeInt16 -> "Int16" - TypeInt32 -> "Int32" - TypeInt64 -> "Int64" - TypeWord -> "Word" - TypeWord8 -> "Word8" - TypeWord16 -> "Word16" - TypeWord32 -> "Word32" - TypeWord64 -> "Word64" + SingleIntegralType t -> bformat formatSingleIntegralType t + VectorIntegralType n t -> bformat (angled (int % " x " % formatSingleIntegralType)) (natVal' n) t + +formatSingleIntegralType :: Format r (SingleIntegralType a -> r) +formatSingleIntegralType = later $ \case + TypeInt n -> bformat ("Int" % int) n + TypeWord n -> bformat ("Word" % int) n formatFloatingType :: Format r (FloatingType a -> r) formatFloatingType = later $ \case - TypeHalf -> "Half" - TypeFloat -> "Float" - TypeDouble -> "Double" + SingleFloatingType t -> bformat formatSingleFloatingType t + VectorFloatingType n t -> bformat (angled (int % " x " % formatSingleFloatingType)) (natVal' n) t + +formatSingleFloatingType :: Format r (SingleFloatingType a -> r) +formatSingleFloatingType = later $ \case + TypeFloat16 -> "Float16" + TypeFloat32 -> "Float32" + TypeFloat64 -> "Float64" + TypeFloat128 -> "Float128" formatNumType :: Format r (NumType a -> r) formatNumType = later $ \case - IntegralNumType ty -> bformat formatIntegralType ty - FloatingNumType ty -> bformat formatFloatingType ty - -formatBoundedType :: Format r (BoundedType a -> r) -formatBoundedType = later $ \case - IntegralBoundedType ty -> bformat formatIntegralType ty - -formatSingleType :: Format r (SingleType a -> r) -formatSingleType = later $ \case - NumSingleType ty -> bformat formatNumType ty + IntegralNumType t -> bformat formatIntegralType t + FloatingNumType t -> bformat formatFloatingType t -formatVectorType :: Format r (VectorType a -> r) -formatVectorType = later $ \case - VectorType n ty -> bformat (angled (int % " x " % formatSingleType)) n ty +formatBitType :: Format r (BitType t -> r) +formatBitType = later $ \case + TypeBit -> "Bit" + TypeMask n -> bformat (angled (int % " x Bit")) (natVal' n) formatScalarType :: Format r (ScalarType a -> r) formatScalarType = later $ \case - SingleScalarType ty -> bformat formatSingleType ty - VectorScalarType ty -> bformat formatVectorType ty + NumScalarType t -> bformat formatNumType t + BitScalarType t -> bformat formatBitType t --- | Querying Integral types --- -class (IsSingle a, IsNum a, IsBounded a) => IsIntegral a where - integralType :: IntegralType a - --- | Querying Floating types --- -class (Floating a, IsSingle a, IsNum a) => IsFloating a where - floatingType :: FloatingType a - --- | Querying Numeric types --- -class (Num a, IsSingle a) => IsNum a where - numType :: NumType a - --- | Querying Bounded types --- -class IsBounded a where - boundedType :: BoundedType a - --- | Querying single value types --- -class IsScalar a => IsSingle a where - singleType :: SingleType a - --- | Querying all scalar types --- -class IsScalar a where - scalarType :: ScalarType a - - -integralDict :: IntegralType a -> IntegralDict a -integralDict TypeInt = IntegralDict -integralDict TypeInt8 = IntegralDict -integralDict TypeInt16 = IntegralDict -integralDict TypeInt32 = IntegralDict -integralDict TypeInt64 = IntegralDict -integralDict TypeWord = IntegralDict -integralDict TypeWord8 = IntegralDict -integralDict TypeWord16 = IntegralDict -integralDict TypeWord32 = IntegralDict -integralDict TypeWord64 = IntegralDict - -floatingDict :: FloatingType a -> FloatingDict a -floatingDict TypeHalf = FloatingDict -floatingDict TypeFloat = FloatingDict -floatingDict TypeDouble = FloatingDict - -singleDict :: SingleType a -> SingleDict a -singleDict = single - where - single :: SingleType a -> SingleDict a - single (NumSingleType t) = num t - - num :: NumType a -> SingleDict a - num (IntegralNumType t) = integral t - num (FloatingNumType t) = floating t - - integral :: IntegralType a -> SingleDict a - integral TypeInt = SingleDict - integral TypeInt8 = SingleDict - integral TypeInt16 = SingleDict - integral TypeInt32 = SingleDict - integral TypeInt64 = SingleDict - integral TypeWord = SingleDict - integral TypeWord8 = SingleDict - integral TypeWord16 = SingleDict - integral TypeWord32 = SingleDict - integral TypeWord64 = SingleDict - - floating :: FloatingType a -> SingleDict a - floating TypeHalf = SingleDict - floating TypeFloat = SingleDict - floating TypeDouble = SingleDict - - -scalarTypeInt :: ScalarType Int -scalarTypeInt = SingleScalarType $ NumSingleType $ IntegralNumType TypeInt - -scalarTypeWord :: ScalarType Word -scalarTypeWord = SingleScalarType $ NumSingleType $ IntegralNumType TypeWord - -scalarTypeInt32 :: ScalarType Int32 -scalarTypeInt32 = SingleScalarType $ NumSingleType $ IntegralNumType TypeInt32 - -scalarTypeWord8 :: ScalarType Word8 -scalarTypeWord8 = SingleScalarType $ NumSingleType $ IntegralNumType TypeWord8 - -scalarTypeWord32 :: ScalarType Word32 -scalarTypeWord32 = SingleScalarType $ NumSingleType $ IntegralNumType TypeWord32 - rnfScalarType :: ScalarType t -> () -rnfScalarType (SingleScalarType t) = rnfSingleType t -rnfScalarType (VectorScalarType t) = rnfVectorType t - -rnfSingleType :: SingleType t -> () -rnfSingleType (NumSingleType t) = rnfNumType t +rnfScalarType (NumScalarType t) = rnfNumType t +rnfScalarType (BitScalarType t) = rnfBitType t -rnfVectorType :: VectorType t -> () -rnfVectorType (VectorType !_ t) = rnfSingleType t - -rnfBoundedType :: BoundedType t -> () -rnfBoundedType (IntegralBoundedType t) = rnfIntegralType t +rnfBitType :: BitType t -> () +rnfBitType TypeBit = () +rnfBitType (TypeMask !_) = () rnfNumType :: NumType t -> () rnfNumType (IntegralNumType t) = rnfIntegralType t rnfNumType (FloatingNumType t) = rnfFloatingType t rnfIntegralType :: IntegralType t -> () -rnfIntegralType TypeInt = () -rnfIntegralType TypeInt8 = () -rnfIntegralType TypeInt16 = () -rnfIntegralType TypeInt32 = () -rnfIntegralType TypeInt64 = () -rnfIntegralType TypeWord = () -rnfIntegralType TypeWord8 = () -rnfIntegralType TypeWord16 = () -rnfIntegralType TypeWord32 = () -rnfIntegralType TypeWord64 = () +rnfIntegralType (SingleIntegralType t) = rnfSingleIntegralType t +rnfIntegralType (VectorIntegralType !_ t) = rnfSingleIntegralType t + +rnfSingleIntegralType :: SingleIntegralType t -> () +rnfSingleIntegralType (TypeInt !_) = () +rnfSingleIntegralType (TypeWord !_) = () rnfFloatingType :: FloatingType t -> () -rnfFloatingType TypeHalf = () -rnfFloatingType TypeFloat = () -rnfFloatingType TypeDouble = () +rnfFloatingType (SingleFloatingType t) = rnfSingleFloatingType t +rnfFloatingType (VectorFloatingType !_ t) = rnfSingleFloatingType t +rnfSingleFloatingType :: SingleFloatingType t -> () +rnfSingleFloatingType TypeFloat16 = () +rnfSingleFloatingType TypeFloat32 = () +rnfSingleFloatingType TypeFloat64 = () +rnfSingleFloatingType TypeFloat128 = () -liftScalar :: ScalarType t -> t -> CodeQ t -liftScalar (SingleScalarType t) = liftSingle t -liftScalar (VectorScalarType t) = liftVector t -liftSingle :: SingleType t -> t -> CodeQ t -liftSingle (NumSingleType t) = liftNum t +liftScalar :: ScalarType t -> t -> CodeQ t +liftScalar (NumScalarType t) = liftNum t +liftScalar (BitScalarType t) = liftBit t -liftVector :: VectorType t -> t -> CodeQ t -liftVector VectorType{} = liftVec +liftBit :: BitType t -> t -> CodeQ t +liftBit TypeBit (Bit x) = [|| Bit x ||] +liftBit TypeMask{} x = liftVec x liftNum :: NumType t -> t -> CodeQ t liftNum (IntegralNumType t) = liftIntegral t liftNum (FloatingNumType t) = liftFloating t liftIntegral :: IntegralType t -> t -> CodeQ t -liftIntegral TypeInt x = [|| x ||] -liftIntegral TypeInt8 x = [|| x ||] -liftIntegral TypeInt16 x = [|| x ||] -liftIntegral TypeInt32 x = [|| x ||] -liftIntegral TypeInt64 x = [|| x ||] -liftIntegral TypeWord x = [|| x ||] -liftIntegral TypeWord8 x = [|| x ||] -liftIntegral TypeWord16 x = [|| x ||] -liftIntegral TypeWord32 x = [|| x ||] -liftIntegral TypeWord64 x = [|| x ||] +liftIntegral (SingleIntegralType t) = liftSingleIntegral t +liftIntegral (VectorIntegralType _ _) = liftVec + +liftSingleIntegral :: SingleIntegralType t -> t -> CodeQ t +liftSingleIntegral TypeInt8 x = [|| x ||] +liftSingleIntegral TypeInt16 x = [|| x ||] +liftSingleIntegral TypeInt32 x = [|| x ||] +liftSingleIntegral TypeInt64 x = [|| x ||] +liftSingleIntegral TypeWord8 x = [|| x ||] +liftSingleIntegral TypeWord16 x = [|| x ||] +liftSingleIntegral TypeWord32 x = [|| x ||] +liftSingleIntegral TypeWord64 x = [|| x ||] +liftSingleIntegral TypeInt128 (Int128 x y) = [|| Int128 x y ||] +liftSingleIntegral TypeWord128 (Word128 x y) = [|| Word128 x y ||] liftFloating :: FloatingType t -> t -> CodeQ t -liftFloating TypeHalf x = [|| x ||] -liftFloating TypeFloat x = [|| x ||] -liftFloating TypeDouble x = [|| x ||] +liftFloating (SingleFloatingType t) = liftSingleFloating t +liftFloating (VectorFloatingType _ _) = liftVec +liftSingleFloating :: SingleFloatingType t -> t -> CodeQ t +liftSingleFloating TypeFloat16 x = [|| x ||] +liftSingleFloating TypeFloat32 x = [|| x ||] +liftSingleFloating TypeFloat64 x = [|| x ||] +liftSingleFloating TypeFloat128 (Float128 x y) = [|| Float128 x y ||] -liftScalarType :: ScalarType t -> CodeQ (ScalarType t) -liftScalarType (SingleScalarType t) = [|| SingleScalarType $$(liftSingleType t) ||] -liftScalarType (VectorScalarType t) = [|| VectorScalarType $$(liftVectorType t) ||] -liftSingleType :: SingleType t -> CodeQ (SingleType t) -liftSingleType (NumSingleType t) = [|| NumSingleType $$(liftNumType t) ||] +liftScalarType :: ScalarType t -> CodeQ (ScalarType t) +liftScalarType (NumScalarType t) = [|| NumScalarType $$(liftNumType t) ||] +liftScalarType (BitScalarType t) = [|| BitScalarType $$(liftBitType t) ||] -liftVectorType :: VectorType t -> CodeQ (VectorType t) -liftVectorType (VectorType n t) = [|| VectorType n $$(liftSingleType t) ||] +liftBitType :: BitType t -> CodeQ (BitType t) +liftBitType TypeBit = [|| TypeBit ||] +liftBitType TypeMask{} = [|| TypeMask proxy# ||] liftNumType :: NumType t -> CodeQ (NumType t) liftNumType (IntegralNumType t) = [|| IntegralNumType $$(liftIntegralType t) ||] liftNumType (FloatingNumType t) = [|| FloatingNumType $$(liftFloatingType t) ||] -liftBoundedType :: BoundedType t -> CodeQ (BoundedType t) -liftBoundedType (IntegralBoundedType t) = [|| IntegralBoundedType $$(liftIntegralType t) ||] - liftIntegralType :: IntegralType t -> CodeQ (IntegralType t) -liftIntegralType TypeInt = [|| TypeInt ||] -liftIntegralType TypeInt8 = [|| TypeInt8 ||] -liftIntegralType TypeInt16 = [|| TypeInt16 ||] -liftIntegralType TypeInt32 = [|| TypeInt32 ||] -liftIntegralType TypeInt64 = [|| TypeInt64 ||] -liftIntegralType TypeWord = [|| TypeWord ||] -liftIntegralType TypeWord8 = [|| TypeWord8 ||] -liftIntegralType TypeWord16 = [|| TypeWord16 ||] -liftIntegralType TypeWord32 = [|| TypeWord32 ||] -liftIntegralType TypeWord64 = [|| TypeWord64 ||] +liftIntegralType (SingleIntegralType t) = [|| SingleIntegralType $$(liftSingleIntegralType t) ||] +liftIntegralType (VectorIntegralType _ t) = [|| VectorIntegralType proxy# $$(liftSingleIntegralType t) ||] + +liftSingleIntegralType :: SingleIntegralType t -> CodeQ (SingleIntegralType t) +liftSingleIntegralType (TypeInt n) = [|| TypeInt n ||] +liftSingleIntegralType (TypeWord n) = [|| TypeWord n ||] liftFloatingType :: FloatingType t -> CodeQ (FloatingType t) -liftFloatingType TypeHalf = [|| TypeHalf ||] -liftFloatingType TypeFloat = [|| TypeFloat ||] -liftFloatingType TypeDouble = [|| TypeDouble ||] +liftFloatingType (SingleFloatingType t) = [|| SingleFloatingType $$(liftSingleFloatingType t) ||] +liftFloatingType (VectorFloatingType _ t) = [|| VectorFloatingType proxy# $$(liftSingleFloatingType t) ||] + +liftSingleFloatingType :: SingleFloatingType t -> CodeQ (SingleFloatingType t) +liftSingleFloatingType TypeFloat16 = [|| TypeFloat16 ||] +liftSingleFloatingType TypeFloat32 = [|| TypeFloat32 ||] +liftSingleFloatingType TypeFloat64 = [|| TypeFloat64 ||] +liftSingleFloatingType TypeFloat128 = [|| TypeFloat128 ||] --- Type-level bit sizes --- -------------------- +-- Querying types +-- -------------- + +class IsScalar a where + scalarType :: ScalarType a + +class IsBit a where + bitType :: BitType a + +class IsNum a where + numType :: NumType a + +class IsIntegral a where + integralType :: IntegralType a + +class IsFloating a where + floatingType :: FloatingType a + +class IsSingleIntegral a where + singleIntegralType :: SingleIntegralType a + +class IsSingleFloating a where + singleFloatingType :: SingleFloatingType a -- | Constraint that values of these two types have the same bit width -- type BitSizeEq a b = (BitSize a == BitSize b) ~ 'True type family BitSize a :: Nat - --- Instances --- --------- --- --- Generate instances for the IsX classes. It would be preferable to do this --- automatically based on the members of the IntegralType (etc.) representations --- (see for example FromIntegral.hs) but TH phase restrictions would require us --- to split this into a separate module. --- - runQ $ do let - bits :: FiniteBits b => b -> Integer - bits = toInteger . finiteBitSize - - integralTypes :: [(Name, Integer)] - integralTypes = - [ (''Int, bits (undefined::Int)) - , (''Int8, 8) - , (''Int16, 16) - , (''Int32, 32) - , (''Int64, 64) - , (''Word, bits (undefined::Word)) - , (''Word8, 8) - , (''Word16, 16) - , (''Word32, 32) - , (''Word64, 64) - ] + integralTypes :: [Integer] + integralTypes = [8,16,32,64,128] floatingTypes :: [(Name, Integer)] floatingTypes = - [ (''Half, 16) - , (''Float, 32) - , (''Double, 64) + [ (''Half, 16) + , (''Float, 32) + , (''Double, 64) + , (''Float128, 128) ] - vectorTypes :: [(Name, Integer)] - vectorTypes = integralTypes ++ floatingTypes + mkIntegral :: String -> Integer -> Q [Dec] + mkIntegral name bits = + let t = conT $ mkName $ printf "%s%d" name bits + c = conE $ mkName $ printf "Type%s%d" name bits + in + [d| instance IsScalar $t where + scalarType = NumScalarType numType - mkIntegral :: Name -> Integer -> Q [Dec] - mkIntegral t n = - [d| instance IsIntegral $(conT t) where - integralType = $(conE (mkName ("Type" ++ nameBase t))) - - instance IsNum $(conT t) where + instance IsNum $t where numType = IntegralNumType integralType - instance IsBounded $(conT t) where - boundedType = IntegralBoundedType integralType + instance IsIntegral $t where + integralType = SingleIntegralType singleIntegralType + + instance IsSingleIntegral $t where + singleIntegralType = $c - instance IsSingle $(conT t) where - singleType = NumSingleType numType + instance KnownNat n => IsIntegral (Vec n $t) where + integralType = VectorIntegralType proxy# $c - instance IsScalar $(conT t) where - scalarType = SingleScalarType singleType + instance KnownNat n => IsNum (Vec n $t) where + numType = IntegralNumType integralType - type instance BitSize $(conT t) = $(litT (numTyLit n)) + instance KnownNat n => IsScalar (Vec n $t) where + scalarType = NumScalarType numType + + type instance BitSize $t = $(litT (numTyLit bits)) + type instance BitSize (Vec n $t) = n GHC.TypeNats.* $(litT (numTyLit bits)) |] mkFloating :: Name -> Integer -> Q [Dec] - mkFloating t n = - [d| instance IsFloating $(conT t) where - floatingType = $(conE (mkName ("Type" ++ nameBase t))) - - instance IsNum $(conT t) where + mkFloating name bits = + let t = conT name + c = conE $ mkName $ printf "TypeFloat%d" bits + in + [d| instance IsScalar $t where + scalarType = NumScalarType numType + + instance IsNum $t where numType = FloatingNumType floatingType - instance IsSingle $(conT t) where - singleType = NumSingleType numType + instance IsFloating $t where + floatingType = SingleFloatingType singleFloatingType - instance IsScalar $(conT t) where - scalarType = SingleScalarType singleType + instance IsSingleFloating $t where + singleFloatingType = $c - type instance BitSize $(conT t) = $(litT (numTyLit n)) - |] + instance KnownNat n => IsFloating (Vec n $t) where + floatingType = VectorFloatingType proxy# $c + + instance KnownNat n => IsNum (Vec n $t) where + numType = FloatingNumType floatingType - mkVector :: Name -> Integer -> Q [Dec] - mkVector t n = - [d| instance KnownNat n => IsScalar (Vec n $(conT t)) where - scalarType = VectorScalarType (VectorType (fromIntegral (natVal' (proxy# :: Proxy# n))) singleType) + instance KnownNat n => IsScalar (Vec n $t) where + scalarType = NumScalarType numType - type instance BitSize (Vec w $(conT t)) = w GHC.TypeLits.* $(litT (numTyLit n)) + type instance BitSize $t = $(litT (numTyLit bits)) + type instance BitSize (Vec n $t) = n GHC.TypeNats.* $(litT (numTyLit bits)) |] - -- - is <- mapM (uncurry mkIntegral) integralTypes + + ss <- mapM (mkIntegral "Int") integralTypes + us <- mapM (mkIntegral "Word") integralTypes fs <- mapM (uncurry mkFloating) floatingTypes - vs <- mapM (uncurry mkVector) vectorTypes -- - return (concat is ++ concat fs ++ concat vs) + return (concat ss ++ concat us ++ concat fs) + +type instance BitSize Bit = 1 +type instance BitSize (Vec n Bit) = n + +instance IsScalar Bit where + scalarType = BitScalarType bitType + +instance KnownNat n => IsScalar (Vec n Bit) where + scalarType = BitScalarType bitType + +instance IsBit Bit where + bitType = TypeBit + +instance KnownNat n => IsBit (Vec n Bit) where + bitType = TypeMask proxy# + +-- Determine the underlying type of a Haskell Int and Word +-- +runQ [d| type INT = $( + case finiteBitSize (undefined::Int) of + 8 -> [t| Int8 |] + 16 -> [t| Int16 |] + 32 -> [t| Int32 |] + 64 -> [t| Int64 |] + _ -> error "I don't know what architecture I am" ) |] + +runQ [d| type WORD = $( + case finiteBitSize (undefined::Word) of + 8 -> [t| Word8 |] + 16 -> [t| Word16 |] + 32 -> [t| Word32 |] + 64 -> [t| Word64 |] + _ -> error "I don't know what architecture I am" ) |] diff --git a/src/Data/Array/Accelerate/Unsafe.hs b/src/Data/Array/Accelerate/Unsafe.hs index 289cf7a2b..70b7fc677 100644 --- a/src/Data/Array/Accelerate/Unsafe.hs +++ b/src/Data/Array/Accelerate/Unsafe.hs @@ -1,5 +1,8 @@ -{-# LANGUAGE MonoLocalBinds #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} -- | -- Module : Data.Array.Accelerate.Unsafe -- Copyright : [2009..2020] The Accelerate Team @@ -24,10 +27,16 @@ module Data.Array.Accelerate.Unsafe ( import Data.Array.Accelerate.Smart import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Array +import Data.Array.Accelerate.Sugar.Shape --- | The function 'coerce' allows you to convert a value between any two types --- whose underlying representations have the same bit size at each component. +-- | The class 'Coercible' reinterprets the bits of a value or array of values +-- as that of a different type. +-- +-- At the expression level, this allows you to convert a value between any two +-- types whose underlying representations have the same bit size at each +-- component. -- -- For example: -- @@ -42,11 +51,38 @@ import Data.Array.Accelerate.Sugar.Elt -- abstract type to the concrete type by dropping the extra @()@ from the -- representation, and vice-versa. -- --- The type class 'Coerce' assures that there is a coercion between the two --- types. +-- At the array level this may also entail changing the size of the innermost +-- dimension. -- --- @since 1.2.0.0 +-- For example: +-- +-- > coerce (x :: Acc (Vector Float)) :: Acc (Vector (Complex Float)) +-- +-- will result in an array with half as many elements, as each element now +-- consists of two values (the real and imaginary values laid out consecutively +-- in memory, and now interpreted as a single packed 'Vec 2 Float'). For this to +-- be safe, the size of 'x' must therefore be even. +-- +-- Note that when applied at the array level 'coerce' prevents array fusion. +-- Therefore if the bit size of the source and target value types is the same, +-- then: +-- +-- > map f . coerce . map g -- -coerce :: Coerce (EltR a) (EltR b) => Exp a -> Exp b -coerce = mkCoerce +-- will result in two kernels being executed, whereas: +-- +-- > map f . map coerce . map g +-- +-- will fuse into a single kernel. +-- +-- @since 1.4.0.0 +-- +class Coercible f a b where + coerce :: f a -> f b + +instance Acoerce (EltR a) (EltR b) => Coercible Acc (Array (sh :. Int) a) (Array (sh :. Int) b) where + coerce = mkAcoerce + +instance Coerce (EltR a) (EltR b) => Coercible Exp a b where + coerce = mkCoerce diff --git a/src/Data/Numeric/Float128.hs b/src/Data/Numeric/Float128.hs new file mode 100644 index 000000000..fba3757d9 --- /dev/null +++ b/src/Data/Numeric/Float128.hs @@ -0,0 +1,591 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE UnboxedTuples #-} +{-# OPTIONS_GHC -fobject-code #-} +{-# OPTIONS_HADDOCK hide #-} +-- | +-- Module : Data.Numeric.Float128 +-- Copyright : [2008..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- +-- IEEE 128-bit floating point type (quadruple precision), consisting of +-- a 15-bit signed exponent and 113-bit mantissa (compared to 11-bit and +-- 53-bit respectively for IEEE double precision). +-- +-- Partly stolen from the (unmaintained) float128 package (BSD3 license). +-- We required the Float128 data type even if operations on that type are +-- not implemented, and link against the methods from the quadmath library +-- (as this is what LLVM will generate). +-- + +module Data.Numeric.Float128 ( + + Float128(..), + +) where + +import Numeric +import Data.Bits +import Data.Primitive.Types +import Data.Ratio +import Foreign.Marshal.Alloc +import Foreign.Marshal.Utils +import Foreign.Ptr +import Foreign.Storable +import System.IO.Unsafe + +import GHC.Base +import GHC.Int +import GHC.Integer.Logarithms +import GHC.Word + +#if defined(FLOAT128_ENABLE) && !defined(__GHCIDE__) +import Language.Haskell.TH.Syntax +#endif + + +-- | A 128-bit floating point number +-- +data Float128 = Float128 !Word64 !Word64 + deriving Eq + +instance Show Float128 where + showsPrec p x = showParen (x < 0 && p > 6) (showFloat x) + +instance Read Float128 where + readsPrec _ = readSigned readFloat + +instance Ord Float128 where + (<) = cmp c_ltq + (>) = cmp c_gtq + (<=) = cmp c_leq + (>=) = cmp c_geq + min = call2 c_fminq + max = call2 c_fmaxq + +instance Num Float128 where + (+) = call2 c_addq + (-) = call2 c_subq + (*) = call2 c_mulq + negate = call1 c_negateq + abs = call1 c_absq + signum = call1 c_signumq + fromInteger z = encodeFloat z 0 + +instance Fractional Float128 where + (/) = call2 c_divq + recip = call1 c_recipq + fromRational q = -- FIXME accuracy? + let a = fromInteger (numerator q) / fromInteger (denominator q) + b = fromInteger (numerator r) / fromInteger (denominator r) + r = q - toRational a + in a + b + +instance Floating Float128 where + pi = unsafePerformIO $ alloca $ \pr -> do + c_piq pr + peek pr + exp = call1 c_expq + log = call1 c_logq + sqrt = call1 c_sqrtq + (**) = call2 c_powq + sin = call1 c_sinq + cos = call1 c_cosq + tan = call1 c_tanq + asin = call1 c_asinq + acos = call1 c_acosq + atan = call1 c_atanq + sinh = call1 c_sinhq + cosh = call1 c_coshq + tanh = call1 c_tanhq + asinh = call1 c_asinhq + acosh = call1 c_acoshq + atanh = call1 c_atanhq + log1p = call1 c_log1pq + expm1 = call1 c_expm1q + +instance Real Float128 where + toRational l = + case decodeFloat l of + (m, e) + | e >= 0 -> m `shiftL` e % 1 + | otherwise -> m % bit (negate e) + +instance RealFrac Float128 where + properFraction l + | l >= 0 = let n' = floor' l + n = fromInteger (toInteger' n') + f = l - n' + in (n, f) + | l < 0 = let n' = ceil' l + n = fromInteger (toInteger' n') + f = l - n' + in (n, f) + | otherwise = (0, l) -- NaN + + truncate = fromInteger . toInteger' . trunc' + round = fromInteger . toInteger' . round' + floor = fromInteger . toInteger' . floor' + ceiling = fromInteger . toInteger' . ceil' + +toInteger' :: Float128 -> Integer +toInteger' l = + case decodeFloat l of + (m, e) + | e >= 0 -> m `shiftL` e + | otherwise -> m `shiftR` negate e + +round', trunc', floor', ceil' :: Float128 -> Float128 +round' = call1 c_roundq +trunc' = call1 c_truncq +floor' = call1 c_floorq +ceil' = call1 c_ceilq + +instance RealFloat Float128 where + isIEEE _ = True + floatRadix _ = 2 + floatDigits _ = 113 -- quadmath.h:FLT128_MANT_DIG + floatRange _ = (-16381,16384) -- quadmath.h:FLT128_MIN_EXP, FLT128_MAX_EXP + isNaN = tst c_isnanq + isInfinite = tst c_isinfq + isNegativeZero = tst c_isnegzeroq + isDenormalized = tst c_isdenormq + atan2 = call2 c_atan2q + + decodeFloat l@(Float128 msw lsw) + | isNaN l = (0, 0) + | isInfinite l = (0, 0) + | l == 0 = (0, 0) + | isDenormalized l = + case decodeFloat (scaleFloat 128 l) of + (m, e) -> (m, e - 128) + | otherwise = + let s = shiftR msw 48 `testBit` 15 + m0 = shiftL (0x1000000000000 .|. toInteger msw .&. 0xFFFFffffFFFF) 64 .|. toInteger lsw + e0 = shiftR msw 48 .&. (bit 15 - 1) + m = if s then negate m0 else m0 + e = fromIntegral e0 - 16383 - 112 -- FIXME verify + in (m, e) + + encodeFloat m e + | m == 0 = Float128 0 0 + | m < 0 = negate (encodeFloat (negate m) e) + | b >= bit 15 - 1 = Float128 (shiftL (bit 15 - 1) 48) 0 -- infinity + | b <= 0 = scaleFloat (b - 128) (encodeFloat m (e - b + 128)) -- denormal + | otherwise = Float128 msw lsw -- normal + where + l = I# (integerLog2# m) + t = l - 112 -- FIXME verify + m' | t >= 0 = m `shiftR` t + | otherwise = m `shiftL` negate t + -- FIXME: verify that m' `testBit` 112 == True + lsw = fromInteger (m' .&. 0xFFFFffffFFFFffff) + msw = fromInteger (shiftR m' 64 .&. 0xFFFFffffFFFF) .|. shiftL (fromIntegral b) 48 + b = e + t + 16383 + 112 -- FIXME verify + + exponent l@(Float128 msw _) + | isNaN l = 0 + | isInfinite l = 0 + | l == 0 = 0 + | isDenormalized l = snd (decodeFloat l) + 113 + | otherwise = let e0 = shiftR msw 48 .&. (bit 15 - 1) + in fromIntegral e0 - 16383 - 112 + 113 + + significand l = unsafePerformIO $ + with l $ \lp -> + alloca $ \ep -> do + c_frexpq lp lp ep + peek lp + + scaleFloat e l = unsafePerformIO $ + with l $ \lp -> do + c_ldexpq lp lp (fromIntegral e) + peek lp + +instance Storable Float128 where + {-# INLINE sizeOf #-} + {-# INLINE alignment #-} + {-# INLINE peek #-} + {-# INLINE poke #-} + sizeOf _ = 16 + alignment _ = 16 + peek = peek128 + peekElemOff = peekElemOff128 + poke = poke128 + pokeElemOff = pokeElemOff128 + +instance Prim Float128 where + {-# INLINE sizeOf# #-} + {-# INLINE alignment# #-} + {-# INLINE indexByteArray# #-} + {-# INLINE readByteArray# #-} + {-# INLINE writeByteArray# #-} + {-# INLINE setByteArray# #-} + {-# INLINE indexOffAddr# #-} + {-# INLINE readOffAddr# #-} + {-# INLINE writeOffAddr# #-} + {-# INLINE setOffAddr# #-} + sizeOf# _ = 16# + alignment# _ = 16# + indexByteArray# = indexByteArray128# + readByteArray# = readByteArray128# + writeByteArray# = writeByteArray128# + setByteArray# = setByteArray128# + indexOffAddr# = indexOffAddr128# + readOffAddr# = readOffAddr128# + writeOffAddr# = writeOffAddr128# + setOffAddr# = setOffAddr128# + +{-# INLINE peek128 #-} +peek128 :: Ptr Float128 -> IO Float128 +peek128 ptr = Float128 <$> peekElemOff (castPtr ptr) index1 + <*> peekElemOff (castPtr ptr) index0 + +{-# INLINE peekElemOff128 #-} +peekElemOff128 :: Ptr Float128 -> Int -> IO Float128 +peekElemOff128 ptr i = + let i2 = 2 * i + in Float128 <$> peekElemOff (castPtr ptr) (i2 + index1) + <*> peekElemOff (castPtr ptr) (i2 + index0) + +{-# INLINE poke128 #-} +poke128 :: Ptr Float128 -> Float128 -> IO () +poke128 ptr (Float128 a b) = do + pokeElemOff (castPtr ptr) index1 a + pokeElemOff (castPtr ptr) index0 b + +{-# INLINE pokeElemOff128 #-} +pokeElemOff128 :: Ptr Float128 -> Int -> Float128 -> IO () +pokeElemOff128 ptr i (Float128 a1 a0) = + let i2 = 2 * i + in do pokeElemOff (castPtr ptr) (i2 + index0) a0 + pokeElemOff (castPtr ptr) (i2 + index1) a1 + +{-# INLINE indexByteArray128# #-} +indexByteArray128# :: ByteArray# -> Int# -> Float128 +indexByteArray128# arr# i# = + let i2# = 2# *# i# + x = indexByteArray# arr# (i2# +# unInt index1) + y = indexByteArray# arr# (i2# +# unInt index0) + in Float128 x y + +{-# INLINE readByteArray128# #-} +readByteArray128# :: MutableByteArray# s -> Int# -> State# s -> (# State# s, Float128 #) +readByteArray128# arr# i# s0 = + let i2# = 2# *# i# + in case readByteArray# arr# (i2# +# unInt index1) s0 of { (# s1, x #) -> + case readByteArray# arr# (i2# +# unInt index0) s1 of { (# s2, y #) -> + (# s2, Float128 x y #) + }} + +{-# INLINE writeByteArray128# #-} +writeByteArray128# :: MutableByteArray# s -> Int# -> Float128 -> State# s -> State# s +writeByteArray128# arr# i# (Float128 a b) s0 = + let i2# = 2# *# i# + in case writeByteArray# arr# (i2# +# unInt index1) a s0 of { s1 -> + case writeByteArray# arr# (i2# +# unInt index0) b s1 of { s2 -> + s2 + }} + +{-# INLINE setByteArray128# #-} +setByteArray128# :: MutableByteArray# s -> Int# -> Int# -> Float128 -> State# s -> State# s +setByteArray128# = defaultSetByteArray# + +{-# INLINE indexOffAddr128# #-} +indexOffAddr128# :: Addr# -> Int# -> Float128 +indexOffAddr128# addr# i# = + let i2# = 2# *# i# + x = indexOffAddr# addr# (i2# +# unInt index1) + y = indexOffAddr# addr# (i2# +# unInt index0) + in Float128 x y + +{-# INLINE readOffAddr128# #-} +readOffAddr128# :: Addr# -> Int# -> State# s -> (# State# s, Float128 #) +readOffAddr128# addr# i# s0 = + let i2# = 2# *# i# + in case readOffAddr# addr# (i2# +# unInt index1) s0 of { (# s1, x #) -> + case readOffAddr# addr# (i2# +# unInt index0) s1 of { (# s2, y #) -> + (# s2, Float128 x y #) + }} + +{-# INLINE writeOffAddr128# #-} +writeOffAddr128# :: Addr# -> Int# -> Float128 -> State# s -> State# s +writeOffAddr128# addr# i# (Float128 a b) s0 = + let i2# = 2# *# i# + in case writeOffAddr# addr# (i2# +# unInt index1) a s0 of { s1 -> + case writeOffAddr# addr# (i2# +# unInt index0) b s1 of { s2 -> + s2 + }} + +{-# INLINE setOffAddr128# #-} +setOffAddr128# :: Addr# -> Int# -> Int# -> Float128 -> State# s -> State# s +setOffAddr128# = defaultSetOffAddr# + +{-# INLINE unInt #-} +unInt :: Int -> Int# +unInt (I# i#) = i# + +-- Use these indices to get the peek/poke ordering endian correct. +{-# INLINE index0 #-} +{-# INLINE index1 #-} +index0, index1 :: Int +#if WORDS_BIGENDIAN +index0 = 1 +index1 = 0 +#else +index0 = 0 +index1 = 1 +#endif + +type F1 = Ptr Float128 -> Ptr Float128 -> IO () +type F2 = Ptr Float128 -> Ptr Float128 -> Ptr Float128 -> IO () +type CMP = Ptr Float128 -> Ptr Float128 -> IO Int32 +type TST = Ptr Float128 -> IO Int32 + +{-# INLINE call1 #-} +call1 :: F1 -> Float128 -> Float128 +call1 f x = unsafePerformIO $ + alloca $ \pr -> + with x $ \px -> do + f pr px + peek pr + +{-# INLINE call2 #-} +call2 :: F2 -> Float128 -> Float128 -> Float128 +call2 f x y = unsafePerformIO $ + alloca $ \pr -> + with x $ \px -> do + with y $ \py -> do + f pr px py + peek pr + +{-# INLINE cmp #-} +cmp :: CMP -> Float128 -> Float128 -> Bool +cmp f x y = unsafePerformIO $ + with x $ \px -> + with y $ \py -> + toBool <$> f px py + +{-# INLINE tst #-} +tst :: TST -> Float128 -> Bool +tst f x = unsafePerformIO $ + with x $ \px -> + toBool <$> f px + + +-- SEE: [HLS and GHC IDE] +-- +#if defined(FLOAT128_ENABLE) && !defined(__GHCIDE__) + +foreign import ccall unsafe "_addq" c_addq :: F2 +foreign import ccall unsafe "_subq" c_subq :: F2 +foreign import ccall unsafe "_mulq" c_mulq :: F2 +foreign import ccall unsafe "_absq" c_absq :: F1 +foreign import ccall unsafe "_negateq" c_negateq :: F1 +foreign import ccall unsafe "_signumq" c_signumq :: F1 + +foreign import ccall unsafe "_divq" c_divq :: F2 +foreign import ccall unsafe "_recipq" c_recipq :: F1 + +foreign import ccall unsafe "_piq" c_piq :: Ptr Float128 -> IO () +foreign import ccall unsafe "_expq" c_expq :: F1 +foreign import ccall unsafe "_logq" c_logq :: F1 +foreign import ccall unsafe "_sqrtq" c_sqrtq :: F1 +foreign import ccall unsafe "_powq" c_powq :: F2 +foreign import ccall unsafe "_sinq" c_sinq :: F1 +foreign import ccall unsafe "_cosq" c_cosq :: F1 +foreign import ccall unsafe "_tanq" c_tanq :: F1 +foreign import ccall unsafe "_asinq" c_asinq :: F1 +foreign import ccall unsafe "_acosq" c_acosq :: F1 +foreign import ccall unsafe "_atanq" c_atanq :: F1 +foreign import ccall unsafe "_sinhq" c_sinhq :: F1 +foreign import ccall unsafe "_coshq" c_coshq :: F1 +foreign import ccall unsafe "_tanhq" c_tanhq :: F1 +foreign import ccall unsafe "_asinhq" c_asinhq :: F1 +foreign import ccall unsafe "_acoshq" c_acoshq :: F1 +foreign import ccall unsafe "_atanhq" c_atanhq :: F1 +foreign import ccall unsafe "_log1pq" c_log1pq :: F1 +foreign import ccall unsafe "_expm1q" c_expm1q :: F1 + +foreign import ccall unsafe "_roundq" c_roundq :: F1 +foreign import ccall unsafe "_truncq" c_truncq :: F1 +foreign import ccall unsafe "_floorq" c_floorq :: F1 +foreign import ccall unsafe "_ceilq" c_ceilq :: F1 + +foreign import ccall unsafe "_isnanq" c_isnanq :: TST +foreign import ccall unsafe "_isinfq" c_isinfq :: TST +foreign import ccall unsafe "_isnegzeroq" c_isnegzeroq :: TST +foreign import ccall unsafe "_isdenormq" c_isdenormq :: TST +foreign import ccall unsafe "_frexpq" c_frexpq :: Ptr Float128 -> Ptr Float128 -> Ptr Int32 -> IO () +foreign import ccall unsafe "_ldexpq" c_ldexpq :: Ptr Float128 -> Ptr Float128 -> Int32 -> IO () +foreign import ccall unsafe "_atan2q" c_atan2q :: F2 + +foreign import ccall unsafe "_ltq" c_ltq :: CMP +foreign import ccall unsafe "_ltq" c_gtq :: CMP +foreign import ccall unsafe "_ltq" c_leq :: CMP +foreign import ccall unsafe "_ltq" c_geq :: CMP +foreign import ccall unsafe "_fminq" c_fminq :: F2 +foreign import ccall unsafe "_fmaxq" c_fmaxq :: F2 + +-- foreign import ccall unsafe "_readq" c_readq :: Ptr Float128 -> Ptr CChar -> IO () +-- foreign import ccall unsafe "_showq" c_showq :: Ptr CChar -> CSize -> Ptr Float128 -> IO () + +-- SEE: [linking to .c files] +-- +runQ $ do + addForeignFilePath LangC "cbits/float128.c" + return [] + +#else + +c_addq :: F2 +c_addq = not_enabled + +c_subq :: F2 +c_subq = not_enabled + +c_mulq :: F2 +c_mulq = not_enabled + +c_absq :: F1 +c_absq = not_enabled + +c_negateq :: F1 +c_negateq = not_enabled + +c_signumq :: F1 +c_signumq = not_enabled + +c_divq :: F2 +c_divq = not_enabled + +c_recipq :: F1 +c_recipq = not_enabled + +c_piq :: Ptr Float128 -> IO () +c_piq = not_enabled + +c_expq :: F1 +c_expq = not_enabled + +c_logq :: F1 +c_logq = not_enabled + +c_sqrtq :: F1 +c_sqrtq = not_enabled + +c_powq :: F2 +c_powq = not_enabled + +c_sinq :: F1 +c_sinq = not_enabled + +c_cosq :: F1 +c_cosq = not_enabled + +c_tanq :: F1 +c_tanq = not_enabled + +c_asinq :: F1 +c_asinq = not_enabled + +c_acosq :: F1 +c_acosq = not_enabled + +c_atanq :: F1 +c_atanq = not_enabled + +c_sinhq :: F1 +c_sinhq = not_enabled + +c_coshq :: F1 +c_coshq = not_enabled + +c_tanhq :: F1 +c_tanhq = not_enabled + +c_asinhq :: F1 +c_asinhq = not_enabled + +c_acoshq :: F1 +c_acoshq = not_enabled + +c_atanhq :: F1 +c_atanhq = not_enabled + +c_log1pq :: F1 +c_log1pq = not_enabled + +c_expm1q :: F1 +c_expm1q = not_enabled + +c_roundq :: F1 +c_roundq = not_enabled + +c_truncq :: F1 +c_truncq = not_enabled + +c_floorq :: F1 +c_floorq = not_enabled + +c_ceilq :: F1 +c_ceilq = not_enabled + +c_isnanq :: TST +c_isnanq = not_enabled + +c_isinfq :: TST +c_isinfq = not_enabled + +c_isnegzeroq :: TST +c_isnegzeroq = not_enabled + +c_isdenormq :: TST +c_isdenormq = not_enabled + +c_frexpq :: Ptr Float128 -> Ptr Float128 -> Ptr Int32 -> IO () +c_frexpq = not_enabled + +c_ldexpq :: Ptr Float128 -> Ptr Float128 -> Int32 -> IO () +c_ldexpq = not_enabled + +c_atan2q :: F2 +c_atan2q = not_enabled + +c_ltq :: CMP +c_ltq = not_enabled + +c_gtq :: CMP +c_gtq = not_enabled + +c_leq :: CMP +c_leq = not_enabled + +c_geq :: CMP +c_geq = not_enabled + +c_fminq :: F2 +c_fminq = not_enabled + +c_fmaxq :: F2 +c_fmaxq = not_enabled + +-- c_readq :: Ptr Float128 -> Ptr CChar -> IO () +-- c_readq = not_enabled +-- +-- c_showq :: Ptr CChar -> CSize -> Ptr Float128 -> IO () +-- c_showq = not_enabled + +not_enabled :: a +not_enabled = error $ + unlines [ "128-bit floating point numbers are not enabled." + , "Reinstall package 'accelerate' with '-ffloat128' to enable them." + ] +#endif + diff --git a/src/Data/Primitive/Bit.hs b/src/Data/Primitive/Bit.hs new file mode 100644 index 000000000..d286054e9 --- /dev/null +++ b/src/Data/Primitive/Bit.hs @@ -0,0 +1,360 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UnboxedTuples #-} +{-# OPTIONS_HADDOCK hide #-} +-- | +-- Module : Data.Primitive.Bit +-- Copyright : [2008..2022] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module Data.Primitive.Bit ( + + Bit(..), + BitMask(..), + toList, fromList, + extract, insert, zeros, ones, + +) where + +import Data.Array.Accelerate.Error + +import Control.Exception +import Control.Monad.ST +import Data.Bits +import Data.Typeable +import qualified Foreign.Storable as Foreign + +import Data.Primitive.ByteArray +import Data.Primitive.Vec ( Vec(..) ) + +import GHC.Base ( isTrue# ) +import GHC.Generics +import GHC.Int +import GHC.Ptr +import GHC.TypeLits +import GHC.Types ( IO(..) ) +import GHC.Word +import qualified GHC.Exts as GHC + +#if __GLASGOW_HASKELL__ < 902 +import GHC.Prim hiding ( subWord8# ) +#else +import GHC.Prim +#endif + + +-- | A newtype wrapper over 'Bool' whose instances pack bits as efficiently +-- as possible (8 values per byte). Arrays of 'Bit' use 8x less memory than +-- arrays of 'Bool' (which stores one value per byte). However, (parallel) +-- random writes are (almost certainly) slower. +-- +newtype Bit = Bit { unBit :: Bool } + deriving (Eq, Ord, Bounded, Enum, FiniteBits, Bits, Typeable, Generic) + +instance Show Bit where + showsPrec _ (Bit False) = showString "0" + showsPrec _ (Bit True) = showString "1" + +instance Read Bit where + readsPrec p = \case + ' ':rest -> readsPrec p rest + '0':rest -> [(Bit False, rest)] + '1':rest -> [(Bit True, rest)] + _ -> [] + +instance Num Bit where + Bit a * Bit b = Bit (a && b) + Bit a + Bit b = Bit (a /= b) + Bit a - Bit b = Bit (a /= b) + negate = id + abs = id + signum = id + fromInteger = Bit . odd + +instance Real Bit where + toRational = fromIntegral + +instance Integral Bit where + quotRem _ (Bit False) = throw DivideByZero + quotRem x (Bit True) = (x, Bit False) + toInteger (Bit False) = 0 + toInteger (Bit True) = 1 + + +-- | A SIMD vector of 'Bit's +-- +newtype BitMask n = BitMask { unMask :: Vec n Bit } + deriving Eq + -- XXX: We should mask off the unused bits before testing for equality, + -- otherwise we are including junk in the test. TLM 2022-06-07 + +instance KnownNat n => Show (BitMask n) where + show = bin . toList + where + bin :: [Bit] -> String + bin bs = '0':'b': go bs + -- + go [] = [] + go (Bit True :rest) = '1' : go rest + go (Bit False:rest) = '0' : go rest + + +instance KnownNat n => GHC.IsList (BitMask n) where + type Item (BitMask n) = Bit + {-# INLINE toList #-} + {-# INLINE fromList #-} + toList = toList + fromList = fromList + +instance KnownNat n => Foreign.Storable (BitMask n) where + {-# INLINE sizeOf #-} + {-# INLINE alignment #-} + {-# INLINE peek #-} + {-# INLINE poke #-} + {-# INLINE peekElemOff #-} + {-# INLINE pokeElemOff #-} + + alignment _ = 1 + sizeOf _ = + let k = fromIntegral (natVal' (proxy# :: Proxy# n)) + in quot (k + 7) 8 + + peek (Ptr addr#) = + let k = natVal' (proxy# :: Proxy# n) + in if k `rem` 8 /= 0 + then error "TODO: use BitMask.peekElemOff for non-multiple-of-8 sized bit-masks" + else IO $ \s0 -> + case Foreign.sizeOf (undefined :: BitMask n) of { I# bytes# -> + case newByteArray# bytes# s0 of { (# s1, mba# #) -> + case copyAddrToByteArray# addr# mba# 0# bytes# s1 of { s2 -> + case unsafeFreezeByteArray# mba# s2 of { (# s3, ba# #) -> + (# s3, BitMask (Vec ba#) #) + }}}} + + peekElemOff (Ptr addr#) (I# i#) = + let !(I# k#) = fromInteger (natVal' (proxy# :: Proxy# n)) + !ki# = i# *# k# + in + if isTrue# (k# <=# 8#) + then let !(# q#, r# #) = quotRemInt# ki# 8# + !mask# = ((wordToWord8# 1## `uncheckedShiftLWord8#` k#) `subWord8#` wordToWord8# 1##) `uncheckedShiftLWord8#` (8# -# k#) + combine u# v# = (uncheckedShiftLWord8# u# r# `orWord8#` uncheckedShiftRLWord8# v# (8# -# r#)) `andWord8#` mask# + in + if isTrue# (r# +# k# <=# 8#) + -- This element does not not cross the byte boundary + then IO $ \s0 -> + case newByteArray# 1# s0 of { (# s1, mba# #) -> + case readWord8OffAddr# addr# q# s1 of { (# s2, w# #) -> + case writeWord8Array# mba# 0# (uncheckedShiftLWord8# w# r# `andWord8#` mask#) s2 of { s3 -> + case unsafeFreezeByteArray# mba# s3 of { (# s4, ba# #) -> + (# s4, BitMask (Vec ba#) #) + }}}} + -- This element crosses the byte boundary. Read two successive + -- bytes (note that on little-endian we can't just treat this + -- as a 16-bit load) to combine and extract the bits we need. + else IO $ \s0 -> + case newByteArray# 1# s0 of { (# s1, mba# #) -> + case readWord8OffAddr# addr# q# s1 of { (# s2, w0# #) -> + case readWord8OffAddr# addr# (q# +# 1#) s2 of { (# s3, w1# #) -> + case writeWord8Array# mba# 0# (combine w0# w1#) s3 of { s4 -> + case unsafeFreezeByteArray# mba# s4 of { (# s5, ba# #) -> + (# s5, BitMask (Vec ba#) #) + }}}}} + else + if isTrue# (k# <=# 16#) + then error "TODO: BitMask (8..16]" + else + if isTrue# (k# <=# 32#) + then error "TODO: BitMask (16..32]" + else + if isTrue# (k# <=# 64#) + then error "TODO: BitMask (32..64]" + else + error "TODO: BitMask.peekElemOff not yet supported at this size" + + poke (Ptr addr#) (BitMask (Vec ba#)) = + let k = natVal' (proxy# :: Proxy# n) + in if k `rem` 8 /= 0 + then error "TODO: use BitMask.pokeElemOff for non-multiple-of-8 sized bit-masks" + else IO $ \s0 -> + case Foreign.sizeOf (undefined :: BitMask n) of { I# bytes# -> + case copyByteArrayToAddr# ba# 0# addr# bytes# s0 of { + s1 -> (# s1, () #) + }} + + pokeElemOff (Ptr addr#) (I# i#) (BitMask (Vec ba#)) = + let !(I# k#) = fromInteger (natVal' (proxy# :: Proxy# n)) + !ki# = i# *# k# + in + if isTrue# (k# <=# 8#) + then let !(# q#, r# #) = quotRemInt# ki# 8# + !rk# = r# +# k# + in + if isTrue# (rk# <=# 8#) + -- This element does not cross the byte boundary + then let !w# = uncheckedShiftRLWord8# (indexWord8Array# ba# 0#) r# + !mask# = ((wordToWord8# 1## `uncheckedShiftLWord8#` k#) `subWord8#` wordToWord8# 1##) `uncheckedShiftLWord8#` (8# -# rk#) + in IO $ \s0 -> + case readWord8OffAddr# addr# q# s0 of { (# s1, v# #) -> + case writeWord8OffAddr# addr# q# ((v# `andWord8#` complementWord8# mask#) `orWord8#` (w# `andWord8#` mask#)) s1 of { s2 -> + (# s2, () #) + }} + -- This element crosses the byte boundary + else let !w# = indexWord8Array# ba# 0# + !w0# = w# `uncheckedShiftRLWord8#` r# + !w1# = w# `uncheckedShiftLWord8#` (8# -# r#) + !mask# = ((wordToWord8# 1## `uncheckedShiftLWord8#` k#) `subWord8#` wordToWord8# 1##) + !mask0# = mask# `uncheckedShiftRLWord8#` (rk# -# 8#) + !mask1# = mask# `uncheckedShiftLWord8#` (16# -# rk#) + in IO $ \s0 -> + case readWord8OffAddr# addr# q# s0 of { (# s1, v0# #) -> + case readWord8OffAddr# addr# (q# +# 1#) s1 of { (# s2, v1# #) -> + case writeWord8OffAddr# addr# q# ((v0# `andWord8#` complementWord8# mask0#) `orWord8#` (w0# `andWord8#` mask0#)) s2 of { s3 -> + case writeWord8OffAddr# addr# (q# +# 1#) ((v1# `andWord8#` complementWord8# mask1#) `orWord8#` (w1# `andWord8#` mask1#)) s3 of { s4 -> + (# s4, () #) + }}}} + else + error "TODO: BitMask.pokeElemOff not yet supported at this size" + +{-# INLINE toList #-} +toList :: forall n. KnownNat n => BitMask n -> [Bit] +toList (BitMask (Vec ba#)) = concat (unpack 0# []) + where + !(I# n#) = fromInteger (natVal' (proxy# :: Proxy# n)) + + unpack :: Int# -> [[Bit]] -> [[Bit]] + unpack i# acc + | isTrue# (i# <# n#) = + let q# = quotInt# i# 8# + w# = indexWord8Array# ba# q# + lim# = minInt# 8# (n# -# i#) + w8 j# acc' = + if isTrue# (j# <# lim#) + then w8 (j# +# 1#) (Bit (isTrue# (testBitWord8# w# j#)) : acc') + else acc' + in + unpack (i# +# 8#) (w8 0# [] : acc) + | otherwise = acc + +{-# INLINE fromList #-} +fromList :: forall n. KnownNat n => [Bit] -> BitMask n +fromList bits = case byteArrayFromListN bytes (pack bits' []) of + ByteArray ba# -> BitMask (Vec ba#) + where + bits' = take (fromInteger (natVal' (proxy# :: Proxy# n))) bits + bytes = Foreign.sizeOf (undefined :: BitMask n) + + pack :: [Bit] -> [Word8] -> [Word8] + pack [] acc = acc + pack xs acc = + let (h,t) = splitAt 8 xs + w = w8 0 0 h + in + pack t (w : acc) + + w8 :: Int -> Word8 -> [Bit] -> Word8 + w8 !_ !w [] = w + w8 !i !w (Bit True :bs) = w8 (i+1) (setBit w i) bs + w8 !i !w (Bit False:bs) = w8 (i+1) w bs + +{-# INLINE extract #-} +extract :: forall n. KnownNat n => BitMask n -> Int -> Bit +extract (BitMask (Vec ba#)) i@(I# i#) = + let n = fromInteger (natVal' (proxy# :: Proxy# n)) + !(# q#, r# #) = quotRemInt# i# 8# + w# = indexWord8Array# ba# q# + b# = testBitWord8# w# r# + in + boundsCheck "out of range" (i >= 0 && i < n) (Bit (isTrue# b#)) + +{-# INLINE insert #-} +insert :: forall n. KnownNat n => BitMask n -> Int -> Bit -> BitMask n +insert (BitMask (Vec ba#)) i (Bit b) = runST $ do + let n = fromInteger (natVal' (proxy# :: Proxy# n)) + bytes = quot (n+7) 8 + (u,v) = quotRem i 8 + -- + mba <- newByteArray n + copyByteArray mba 0 (ByteArray ba#) 0 bytes + x :: Word8 <- readByteArray mba u + writeByteArray mba u $ if b then setBit x v + else clearBit x v + ByteArray ba'# <- unsafeFreezeByteArray mba + return (BitMask (Vec ba'#)) + +{-# INLINE zeros #-} +zeros :: forall n. KnownNat n => BitMask n +zeros = + let n = fromInteger (natVal' (proxy# :: Proxy# n)) + l = quot (n+7) 8 + in + case byteArrayFromListN l (replicate l (0 :: Word8)) of + ByteArray ba# -> BitMask (Vec ba#) + +{-# INLINE ones #-} +ones :: forall n. KnownNat n => BitMask n +ones = + let n = fromInteger (natVal' (proxy# :: Proxy# n)) + l = quot (n+7) 8 + in + case byteArrayFromListN l (replicate l (0xff :: Word8)) of + ByteArray ba# -> BitMask (Vec ba#) + + +minInt# :: Int# -> Int# -> Int# +minInt# a# b# = + case a# <# b# of + 0# -> b# + _ -> a# + +#if __GLASGOW_HASKELL__ < 902 +testBitWord8# :: Word# -> Int# -> Int# +testBitWord8# x# i# = (x# `and#` bitWord8# i#) `neWord#` 0## + +bitWord8# :: Int# -> Word# +bitWord8# i# = narrow8Word# (1## `uncheckedShiftL#` i#) + +orWord8# :: Word# -> Word# -> Word# +orWord8# = or# + +andWord8# :: Word# -> Word# -> Word# +andWord8# = and# + +complementWord8# :: Word# -> Word# +complementWord8# x# = x# `xor#` 0xff## + +subWord8# :: Word# -> Word# -> Word# +subWord8# = minusWord# + +uncheckedShiftLWord8# :: Word# -> Int# -> Word# +uncheckedShiftLWord8# = uncheckedShiftL# + +uncheckedShiftRLWord8# :: Word# -> Int# -> Word# +uncheckedShiftRLWord8# = uncheckedShiftRL# + +wordToWord8# :: Word# -> Word# +wordToWord8# x = x + +#else +testBitWord8# :: Word8# -> Int# -> Int# +testBitWord8# x# i# = (x# `andWord8#` bitWord8# i#) `neWord8#` wordToWord8# 0## + +bitWord8# :: Int# -> Word8# +bitWord8# i# = wordToWord8# 1## `uncheckedShiftLWord8#` i# + +complementWord8# :: Word8# -> Word8# +complementWord8# x# = x# `xorWord8#` wordToWord8# 0xff## +#endif + diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 0342f401c..a0bf661da 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -1,19 +1,21 @@ {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Primitive.Vec --- Copyright : [2008..2020] The Accelerate Team +-- Copyright : [2008..2022] The Accelerate Team -- License : BSD3 -- -- Maintainer : Trevor L. McDonell @@ -24,54 +26,66 @@ module Data.Primitive.Vec ( -- * SIMD vector types - Vec(..), - Vec2, pattern Vec2, - Vec3, pattern Vec3, - Vec4, pattern Vec4, - Vec8, pattern Vec8, - Vec16, pattern Vec16, - - listOfVec, + Vec(..), KnownNat, + V2, pattern V2, + V3, pattern V3, + V4, pattern V4, + V8, pattern V8, + V16, pattern V16, + + toList, fromList, + extract, insert, splat, liftVec, ) where +import Data.Array.Accelerate.Error + import Control.Monad.ST import Data.Primitive.ByteArray import Data.Primitive.Types import Language.Haskell.TH.Extra import Prettyprinter +import Prelude + +import qualified Foreign.Storable as Foreign import GHC.Base ( isTrue# ) import GHC.Int import GHC.Prim import GHC.TypeLits +import GHC.Types ( IO(..) ) import GHC.Word +import qualified GHC.Exts as GHC + +#if __GLASGOW_HASKELL__ < 808 +import GHC.Ptr +#endif -- Note: [Representing SIMD vector types] -- -- A simple polymorphic representation of SIMD types such as the following: -- --- > data Vec2 a = Vec2 !a !a +-- > data V2 a = V2 !a !a -- -- is not able to unpack the values into the constructor, meaning that --- 'Vec2' is storing pointers to (strict) values on the heap, which is +-- 'V2' is storing pointers to (strict) values on the heap, which is -- a very inefficient representation. -- -- We might try defining a data family instead so that we can get efficient -- unboxed representations, and even make use of the unlifted SIMD types GHC -- knows about: -- --- > data family Vec2 a :: * --- > data instance Vec2 Float = Vec2_Float Float# Float# -- reasonable --- > data instance Vec2 Double = Vec2_Double DoubleX2# -- built in! +-- > data family V2 a :: * +-- > data instance V2 Float = V2_Float Float# Float# -- reasonable +-- > data instance V2 Double = V2_Double DoubleX2# -- built in! -- -- However, this runs into the problem that GHC stores all values as word sized -- entities: -- --- > data instance Vec2 Int = Vec2_Int Int# Int# --- > data instance Vec2 Int8 = Vec2_Int8 Int8# Int8# -- Int8# does not exist; requires a full Int# +-- > data instance V2 Int = V2_Int Int# Int# +-- > data instance V2 Int8 = V2_Int8 Int8# Int8# -- Int8# does not exist; requires a full Int# -- -- which, again, is very memory inefficient. -- @@ -86,77 +100,151 @@ data Vec (n :: Nat) a = Vec ByteArray# type role Vec nominal representational instance (Show a, Prim a, KnownNat n) => Show (Vec n a) where - show = vec . listOfVec + show = vec . toList where vec :: [a] -> String vec = show . group . encloseSep (flatAlt "< " "<") (flatAlt " >" ">") ", " . map viaShow -listOfVec :: forall a n. (Prim a, KnownNat n) => Vec n a -> [a] -listOfVec (Vec ba#) = go 0# +instance (Prim a, KnownNat n) => GHC.IsList (Vec n a) where + type Item (Vec n a) = a + {-# INLINE toList #-} + {-# INLINE fromList #-} + toList = toList + fromList = fromList + +instance (Foreign.Storable a, KnownNat n) => Foreign.Storable (Vec n a) where + {-# INLINE sizeOf #-} + {-# INLINE alignment #-} + {-# INLINE peek #-} + {-# INLINE poke #-} + + sizeOf _ = fromInteger (natVal' (proxy# :: Proxy# n)) * Foreign.sizeOf (undefined :: a) + alignment _ = Foreign.alignment (undefined :: a) + + peek (Ptr addr#) = + IO $ \s0 -> + case Foreign.sizeOf (undefined :: Vec n a) of { I# bytes# -> + case newAlignedPinnedByteArray# bytes# 16# s0 of { (# s1, mba# #) -> + case copyAddrToByteArray# addr# mba# 0# bytes# s1 of { s2 -> + case unsafeFreezeByteArray# mba# s2 of { (# s3, ba# #) -> + (# s3, Vec ba# #) + }}}} + + poke (Ptr addr#) (Vec ba#) = + IO $ \s0 -> + case Foreign.sizeOf (undefined :: Vec n a) of { I# bytes# -> + case copyByteArrayToAddr# ba# 0# addr# bytes# s0 of { + s1 -> (# s1, () #) + }} + +{-# INLINE toList #-} +toList :: forall a n. (Prim a, KnownNat n) => Vec n a -> [a] +toList (Vec ba#) = go 0# where go :: Int# -> [a] go i# | isTrue# (i# <# n#) = indexByteArray# ba# i# : go (i# +# 1#) | otherwise = [] + -- + !(I# n#) = fromInteger (natVal' (proxy# :: Proxy# n)) - !(I# n#) = fromIntegral (natVal' (proxy# :: Proxy# n)) +{-# INLINE fromList #-} +fromList :: forall a n. (Prim a, KnownNat n) => [a] -> Vec n a +fromList xs = + case byteArrayFromListN (fromInteger (natVal' (proxy# :: Proxy# n))) xs of + ByteArray ba# -> Vec ba# instance Eq (Vec n a) where Vec ba1# == Vec ba2# = ByteArray ba1# == ByteArray ba2# +-- | Extract an element from a vector at the given index +-- +{-# INLINE extract #-} +extract :: forall n a. (Prim a, KnownNat n) => Vec n a -> Int -> a +extract (Vec ba#) i@(I# i#) = + let n = fromInteger (natVal' (proxy# :: Proxy# n)) + in boundsCheck "out of range" (i >= 0 && i < n) $ indexByteArray# ba# i# + +-- | Returns a new vector where the element at the specified index has been +-- replaced with the supplied value. +-- +{-# INLINE insert #-} +insert :: forall n a. (Prim a, KnownNat n) => Vec n a -> Int -> a -> Vec n a +insert (Vec ba#) i a = + let n = fromInteger (natVal' (proxy# :: Proxy# n)) + bytes = n * sizeOf (undefined :: a) + in boundsCheck "out of range" (i >= 0 && i < n) + $ runST $ do + mba <- newByteArray bytes + copyByteArray mba 0 (ByteArray ba#) 0 bytes + writeByteArray mba i a + ByteArray ba'# <- unsafeFreezeByteArray mba + return $! Vec ba'# + +-- | Fill all lanes of a vector with the same value +-- +{-# INLINE splat #-} +splat :: forall n a. (Prim a, KnownNat n) => a -> Vec n a +splat x = runST $ do + let n = fromInteger (natVal' (proxy# :: Proxy# n)) + mba <- newByteArray (n * sizeOf (undefined :: a)) + setByteArray mba 0 n x + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# + -- Type synonyms for common SIMD vector types -- -- Note that non-power-of-two sized SIMD vectors are a bit dubious, and -- special care must be taken in the code generator. For example, LLVM will --- treat a Vec3 with alignment of _4_, meaning that reads and writes will +-- treat a V3 with alignment of _4_, meaning that reads and writes will -- be (without further action) incorrect. -- -type Vec2 a = Vec 2 a -type Vec3 a = Vec 3 a -type Vec4 a = Vec 4 a -type Vec8 a = Vec 8 a -type Vec16 a = Vec 16 a - -pattern Vec2 :: Prim a => a -> a -> Vec2 a -pattern Vec2 a b <- (unpackVec2 -> (a,b)) - where Vec2 = packVec2 -{-# COMPLETE Vec2 #-} - -pattern Vec3 :: Prim a => a -> a -> a -> Vec3 a -pattern Vec3 a b c <- (unpackVec3 -> (a,b,c)) - where Vec3 = packVec3 -{-# COMPLETE Vec3 #-} - -pattern Vec4 :: Prim a => a -> a -> a -> a -> Vec4 a -pattern Vec4 a b c d <- (unpackVec4 -> (a,b,c,d)) - where Vec4 = packVec4 -{-# COMPLETE Vec4 #-} - -pattern Vec8 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> Vec8 a -pattern Vec8 a b c d e f g h <- (unpackVec8 -> (a,b,c,d,e,f,g,h)) - where Vec8 = packVec8 -{-# COMPLETE Vec8 #-} - -pattern Vec16 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> Vec16 a -pattern Vec16 a b c d e f g h i j k l m n o p <- (unpackVec16 -> (a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p)) - where Vec16 = packVec16 -{-# COMPLETE Vec16 #-} - -unpackVec2 :: Prim a => Vec2 a -> (a,a) +type V2 a = Vec 2 a +type V3 a = Vec 3 a +type V4 a = Vec 4 a +type V8 a = Vec 8 a +type V16 a = Vec 16 a + +pattern V2 :: Prim a => a -> a -> V2 a +pattern V2 a b <- (unpackVec2 -> (a,b)) + where V2 = packVec2 +{-# COMPLETE V2 #-} + +pattern V3 :: Prim a => a -> a -> a -> V3 a +pattern V3 a b c <- (unpackVec3 -> (a,b,c)) + where V3 = packVec3 +{-# COMPLETE V3 #-} + +pattern V4 :: Prim a => a -> a -> a -> a -> V4 a +pattern V4 a b c d <- (unpackVec4 -> (a,b,c,d)) + where V4 = packVec4 +{-# COMPLETE V4 #-} + +pattern V8 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> V8 a +pattern V8 a b c d e f g h <- (unpackVec8 -> (a,b,c,d,e,f,g,h)) + where V8 = packVec8 +{-# COMPLETE V8 #-} + +pattern V16 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> V16 a +pattern V16 a b c d e f g h i j k l m n o p <- (unpackVec16 -> (a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p)) + where V16 = packVec16 +{-# COMPLETE V16 #-} + +unpackVec2 :: Prim a => V2 a -> (a,a) unpackVec2 (Vec ba#) = ( indexByteArray# ba# 0# , indexByteArray# ba# 1# ) -unpackVec3 :: Prim a => Vec3 a -> (a,a,a) +unpackVec3 :: Prim a => V3 a -> (a,a,a) unpackVec3 (Vec ba#) = ( indexByteArray# ba# 0# , indexByteArray# ba# 1# , indexByteArray# ba# 2# ) -unpackVec4 :: Prim a => Vec4 a -> (a,a,a,a) +unpackVec4 :: Prim a => V4 a -> (a,a,a,a) unpackVec4 (Vec ba#) = ( indexByteArray# ba# 0# , indexByteArray# ba# 1# @@ -164,7 +252,7 @@ unpackVec4 (Vec ba#) = , indexByteArray# ba# 3# ) -unpackVec8 :: Prim a => Vec8 a -> (a,a,a,a,a,a,a,a) +unpackVec8 :: Prim a => V8 a -> (a,a,a,a,a,a,a,a) unpackVec8 (Vec ba#) = ( indexByteArray# ba# 0# , indexByteArray# ba# 1# @@ -176,7 +264,7 @@ unpackVec8 (Vec ba#) = , indexByteArray# ba# 7# ) -unpackVec16 :: Prim a => Vec16 a -> (a,a,a,a,a,a,a,a,a,a,a,a,a,a,a,a) +unpackVec16 :: Prim a => V16 a -> (a,a,a,a,a,a,a,a,a,a,a,a,a,a,a,a) unpackVec16 (Vec ba#) = ( indexByteArray# ba# 0# , indexByteArray# ba# 1# @@ -196,7 +284,7 @@ unpackVec16 (Vec ba#) = , indexByteArray# ba# 15# ) -packVec2 :: Prim a => a -> a -> Vec2 a +packVec2 :: Prim a => a -> a -> V2 a packVec2 a b = runST $ do mba <- newByteArray (2 * sizeOf a) writeByteArray mba 0 a @@ -204,7 +292,7 @@ packVec2 a b = runST $ do ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# -packVec3 :: Prim a => a -> a -> a -> Vec3 a +packVec3 :: Prim a => a -> a -> a -> V3 a packVec3 a b c = runST $ do mba <- newByteArray (3 * sizeOf a) writeByteArray mba 0 a @@ -213,7 +301,7 @@ packVec3 a b c = runST $ do ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# -packVec4 :: Prim a => a -> a -> a -> a -> Vec4 a +packVec4 :: Prim a => a -> a -> a -> a -> V4 a packVec4 a b c d = runST $ do mba <- newByteArray (4 * sizeOf a) writeByteArray mba 0 a @@ -223,7 +311,7 @@ packVec4 a b c d = runST $ do ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# -packVec8 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> Vec8 a +packVec8 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> V8 a packVec8 a b c d e f g h = runST $ do mba <- newByteArray (8 * sizeOf a) writeByteArray mba 0 a @@ -237,7 +325,7 @@ packVec8 a b c d e f g h = runST $ do ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# -packVec16 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> Vec16 a +packVec16 :: Prim a => a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> V16 a packVec16 a b c d e f g h i j k l m n o p = runST $ do mba <- newByteArray (16 * sizeOf a) writeByteArray mba 0 a @@ -259,6 +347,7 @@ packVec16 a b c d e f g h i j k l m n o p = runST $ do ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# + -- O(n) at runtime to copy from the Addr# to the ByteArray#. We should be able -- to do this without copying, but I don't think the definition of ByteArray# is -- exported (or it is deeply magical). diff --git a/src/GHC/TypeLits/Extra.hs b/src/GHC/TypeLits/Extra.hs new file mode 100644 index 000000000..9d1f8dfa2 --- /dev/null +++ b/src/GHC/TypeLits/Extra.hs @@ -0,0 +1,31 @@ +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE TypeOperators #-} +-- | +-- Module : GHC.TypeLits.Extra +-- Copyright : [2012..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- + +module GHC.TypeLits.Extra + where + +import Data.Typeable + +import GHC.Exts +import GHC.TypeLits +import Unsafe.Coerce + + +-- | We either get evidence that this function was instantiated with the same +-- type-level numbers, or 'Nothing'. +-- +{-# INLINEABLE sameNat' #-} +sameNat' :: (KnownNat n, KnownNat m) => Proxy# n -> Proxy# m -> Maybe (n :~: m) +sameNat' n m + | natVal' n == natVal' m = Just (unsafeCoerce Refl) -- same as 'GHC.TypeLits.sameNat' but for 'Proxy#' + | otherwise = Nothing + diff --git a/src/Language/Haskell/TH/Extra.hs b/src/Language/Haskell/TH/Extra.hs index 7688167fc..ab166f4d1 100644 --- a/src/Language/Haskell/TH/Extra.hs +++ b/src/Language/Haskell/TH/Extra.hs @@ -23,7 +23,7 @@ module Language.Haskell.TH.Extra ( #if MIN_VERSION_template_haskell(2,17,0) import Language.Haskell.TH hiding ( plainInvisTV, tupP, tupE ) #else -import Language.Haskell.TH hiding ( TyVarBndr, tupP, tupE ) +import Language.Haskell.TH hiding ( TyVarBndr, tupP, tupE, plainTV ) import Language.Haskell.TH.Syntax ( unTypeQ, unsafeTExpCoerce ) #if MIN_VERSION_template_haskell(2,16,0) import GHC.Exts ( RuntimeRep, TYPE ) @@ -71,6 +71,9 @@ tyVarBndrName :: TyVarBndr flag -> Name tyVarBndrName (PlainTV n) = n tyVarBndrName (KindedTV n _) = n +plainTV :: Name -> TyVarBndr () +plainTV = PlainTV + plainInvisTV :: Name -> Specificity -> TyVarBndr Specificity plainInvisTV n _ = PlainTV n