-- | Graph flattening based on 'Eq'.
--
-- Converts an expression tree ('Graph') into a flat list of 'Binding's in
-- which every structurally distinct subexpression appears once and is given
-- a name drawn from a caller-supplied supply of temporaries.
module Flatten
  ( Graph (In, Node)
  , compileDict
  , inputs
  , outputs
  , temps
  , Binding (Process, Input, Output)
  -- , dictExtend, nodes -- for FlattenMonad
  ) where

import Data.List (foldl')

-- | Tree/graph type to be flattened, parameterised over externally
-- provided types:
--
--   * @v@ = variable
--   * @n@ = node metadata (i.e. opcode)
data Graph n v
  = In v               -- ^ input variable reference
  | Tmp v              -- ^ internal marker (not exported); only used as a
                       --   dictionary key that tags an output's root name
  | Node n [Graph n v] -- ^ node with metadata @n@ and argument subgraphs
  deriving (Eq, Show)

-- | Flat dictionary type produced by 'compileDict'.
data Binding n v
  = Process v n [v] -- ^ @Process tmp op args@: node @op@ over @args@, named @tmp@
  | Input v v       -- ^ @Input tmp mem@: input variable @mem@, named @tmp@
  | Output v v      -- ^ @Output out tmp@: output @out@ refers to name @tmp@
  deriving (Eq, Show)

-- | Sequence all nodes in the graph in postorder (children before their
-- parent). A left fold over this list therefore sees every dependency
-- before the node that uses it. 'Tmp' must not occur in input graphs.
nodes :: Graph n v -> [Graph n v]
nodes n@(Node _ gs) = concatMap nodes gs ++ [n]
nodes i@(In _)      = [i]
nodes (Tmp _)       = error "Input contains temp node."

-- | Foldable step: give @node@ a fresh name unless an equal node (by '==')
-- is already in the dictionary.
--
-- The state is @(unused names, (node, name) assoc list)@. A fresh name is
-- drawn from the supply only when the node is actually inserted, so an
-- exhausted supply is harmless as long as no new nodes are encountered.
dictExtend :: (Eq t) => ([a], [(t, a)]) -> t -> ([a], [(t, a)])
dictExtend state@(supply, dict) node
  | any ((== node) . fst) dict = state
  | otherwise =
      case supply of
        fresh : rest -> (rest, (node, fresh) : dict)
        []           -> error "Flatten.dictExtend: ran out of fresh names."

-- | First step of graph flattening: fold 'dictExtend' over the nodes of
-- each named output expression, building a @(node, name)@ assoc list.
-- After each output, a @(Tmp root, outputName)@ entry records which name
-- holds that output's value. The list is built newest-first.
nameNodes :: (Eq v, Eq n) => [v] -> [(Graph n v, v)] -> [(Graph n v, v)]
nameNodes tmps namedExprs = dict
  where
    -- Start off with an empty dict and recover the dict from the end state.
    (_, dict) = foldl' update (tmps, []) namedExprs

    -- Name every node of one output expression, then tag its root.
    update (supply, table) (expr, name) = (supply', (Tmp root, name) : table')
      where
        (supply', table') = foldl' dictExtend (supply, table) (nodes expr)
        -- Look the root up instead of taking the head of table': if an
        -- earlier output already named this expression (or it is shared
        -- with a previous output), nothing new was consed on, and the
        -- head would be some unrelated entry.
        root = case lookup expr table' of
          Just tmp -> tmp
          Nothing  -> error "Flatten.nameNodes: output root was not named."

-- | Second step: turn the @(node, name)@ dictionary into bindings, in
-- insertion order (dependencies first).
flattenNodes :: (Eq n, Eq v) => [(Graph n v, v)] -> [Binding n v]
flattenNodes dict = map flatten (reverse dict)
  where
    flatten (Node opcode ns, tmp) = Process tmp opcode (map (rename dict) ns)
    flatten (In mem, tmp)         = Input tmp mem
    flatten (Tmp tmp, mem)        = Output mem tmp

-- | Name used to refer to a node's argument: 'Node's are resolved through
-- the dictionary; 'In' and 'Tmp' carry their variable directly.
rename :: (Eq n, Eq v) => [(Graph n v, v)] -> Graph n v -> v
rename dict = f
  where
    f node@(Node _ _) =
      case lookup node dict of
        Just name -> name
        -- Unreachable for dictionaries built by nameNodes: postorder
        -- traversal names every child before its parent.
        Nothing -> error "Flatten.rename: node has no name in dictionary."
    f (In var)  = var
    f (Tmp var) = var

-- | Compilation consists of two passes: construction of a map from nodes
-- to names, and flattening of the dictionary using those names.
--
-- The first argument is the supply of temporary names. It must hold at
-- least as many names as there are distinct nodes across all outputs.
compileDict :: (Eq n, Eq v) => [v] -> [(Graph n v, v)] -> [Binding n v]
compileDict tmps = flattenNodes . nameNodes tmps

-- Query nodes in a binding list (results keep binding order).

-- | Input variables read by 'Input' bindings.
inputs :: [Binding n v] -> [v]
inputs bs = [mem | Input _ mem <- bs]

-- | Output names of 'Output' bindings.
outputs :: [Binding n v] -> [v]
outputs bs = [out | Output out _ <- bs]

-- | Names defined by 'Input' and 'Process' bindings.
temps :: [Binding n v] -> [v]
temps bs = [tmp | b <- bs, tmp <- defined b]
  where
    defined (Input tmp _)     = [tmp]
    defined (Process tmp _ _) = [tmp]
    defined (Output _ _)      = []