]> hydra-www.ietfng.org Git - dyna2/commitdiff
new aggregator interface.
authortimv <tim.f.vieira@gmail.com>
Tue, 18 Dec 2012 23:24:30 +0000 (18:24 -0500)
committertimv <tim.f.vieira@gmail.com>
Tue, 18 Dec 2012 23:24:30 +0000 (18:24 -0500)
tweaks to constants in Python.hs

bin/defn.py
bin/interpreter.py
examples/dijkstra.dyna
examples/papa2.dyna
src/Dyna/Analysis/ANF.hs
src/Dyna/Backend/Python.hs

index 23d189e9ee27de2e1a393cf2143be85d5ccc2492..aac38fcd342eb34e3a032e928043f647e69ec808 100644 (file)
@@ -26,67 +26,116 @@ from collections import defaultdict, Counter
 from utils import red
 
 
-def agg_bind(agg, agg_decl, table):
+class Aggregator(object):
+    def __init__(self, item, name):
+        self.item = item
+        self.name = name
+    def fold(self):
+        raise NotImplementedError
+    def inc(self, val):
+        raise NotImplementedError
+    def dec(self, val):
+        raise NotImplementedError
+    def clear(self):
+        raise NotImplementedError
+    def __repr__(self):
+        return 'Aggregator(%r, %r)' % (self.item, self.name)
+
+
+class MultisetAggregator(Counter, Aggregator):
+    def __init__(self, item, name, folder):
+        self.folder = folder
+        Aggregator.__init__(self, item, name)
+        Counter.__init__(self)
+    def inc(self, val):
+        self[val] += 1
+    def dec(self, val):
+        self[val] -= 1
+    def fold(self):
+        return self.folder(self)
+    def fromkeys(self, *_):
+        raise NotImplementedError
+
+
+class LastEquals(Aggregator):
+    def __init__(self, item, name):
+        self.list = []
+        Aggregator.__init__(self, item, name)
+    def inc(self, val):
+        self.list.append(val)
+    def dec(self, val):
+        raise NotImplementedError('dec on last equal not defined.')
+    def fold(self):
+        return self.list[-1]
+
+
+def agg_bind(item, agg_decl):
     """
     Bind declarations (map functor->string) to table (storing values) and
     aggregator definition (the fold funciton, which gets executed).
     """
 
-    def max_equals(item):
-        s = [k for k, m in table[item].iteritems() if m > 0]
+    def majority_equals(a):
+        [(k,_)] = a.most_common(1)
+        return k
+
+    def max_equals(a):
+        s = [k for k, m in a.iteritems() if m > 0]
         if len(s):
             return max(s)
 
-    def min_equals(item):
-        s = [k for k, m in table[item].iteritems() if m > 0]
+    def min_equals(a):
+        s = [k for k, m in a.iteritems() if m > 0]
         if len(s):
             return min(s)
 
-    def plus_equals(item):
-        s = [k*m for k, m in table[item].iteritems() if m != 0]
+    def plus_equals(a):
+        s = [k*m for k, m in a.iteritems() if m != 0]
         if len(s):
             return reduce(operator.add, s)
 
-    def times_equals(item):
-        s = [k**m for k, m in table[item].iteritems() if m != 0]
+    def times_equals(a):
+        s = [k**m for k, m in a.iteritems() if m != 0]
         if len(s):
             return reduce(operator.mul, s)
 
-    def and_equals(item):
-        s = [k for k, m in table[item].iteritems() if m > 0]
+    def and_equals(a):
+        s = [k for k, m in a.iteritems() if m > 0]
         if len(s):
             return reduce(lambda x,y: x and y, s)
 
-    def or_equals(item):
-        s = [k for k, m in table[item].iteritems() if m > 0]
+    def or_equals(a):
+        s = [k for k, m in a.iteritems() if m > 0]
         if len(s):
             return reduce(lambda x,y: x or y, s)
 
+    def b_and_equals(a):
+        s = [k for k, m in a.iteritems() if m > 0]
+        if len(s):
+            return reduce(operator.and_, s)
+
+    def b_or_equals(a):
+        s = [k for k, m in a.iteritems() if m > 0]
+        if len(s):
+            return reduce(operator.or_, s)
 
     # map names to functions
-    agg_defs = {
+    defs = {
         'max=': max_equals,
         'min=': min_equals,
         '+=': plus_equals,
         '*=': times_equals,
-        '&=': and_equals,
-        '|=': or_equals,
+        'and=': and_equals,
+        'or=': or_equals,
+        '&=': b_and_equals,
+        '|=': b_or_equals,
         ':-': or_equals,
+        'majority=': majority_equals,
     }
 
-    # commit functors to an aggregator definition to avoid unnecessary lookups.
-    for fn in agg_decl:
-
-#        if agg_decl[fn] == ':=':   # XXX: leaves previous version???
-#            raise NotImplementedError("aggregator ':=' not implemented yet.")
-#            continue
-
-        if fn in agg:
-            if agg[fn].__name__ != agg_defs[agg_decl[fn]].__name__:
-                print red % 'conflicting aggregators. %s and %s' % (agg[fn], agg_defs[agg_decl[fn]])
+    (fn, _) = item
 
-        # XXX: ignores conflicts with aggregator, might lead to confusion.
+    if agg_decl[fn] == ':=':
+        return LastEquals(item, agg_decl[fn])
 
-        # TODO: as soon as we ':=' we probably won't want this behavior and we
-        # can restor the assertion that aggregator hasn't changed.
-        agg[fn] = agg_defs[agg_decl[fn]]
+    return MultisetAggregator(item, agg_decl[fn], defs[agg_decl[fn]])
index 4832800d39d59b0196c80a3896bd73f99baa52e2..ef54141c593139e0ebe440e6b0b778aad68b15b9 100644 (file)
@@ -21,6 +21,8 @@
  - TODO: deletion of a rule should be running the initializer for the rule in
    deletion mode.
 
+ - TODO: hooks from introspection, eval, and prioritization.
+
 """
 
 #from debug import ultraTB2; ultraTB2.enable()
@@ -30,7 +32,7 @@ import os, sys
 from collections import defaultdict, Counter
 from argparse import ArgumentParser
 from utils import ip, red, green, blue, magenta
-from defn import agg_bind, call
+from defn import agg_bind
 
 
 # TODO: as soon as we have safe names for these things we can get rid of this.
@@ -42,10 +44,20 @@ class chart_indirect(dict):
         return c
 
 
+class aggregator_indirect(dict):
+    def __missing__(self, item):
+        a = agg_bind(item, agg_decl)
+        self[item] = a
+        return a
+
+
+aggregator = aggregator_indirect()
+
+
 chart = chart_indirect()
 _delete = False
 agenda = set()
-aggregator = defaultdict(Counter)
+#aggregator = defaultdict(Counter)
 agg = {}
 agg_decl = None # filled in after exec, this only here to satisfy lint checker.
 
@@ -156,8 +168,8 @@ def prettify(x):
         raise ValueError("Don't know what to do with %r" % x)
 
 
-# Update handler indirection -- a true hack. Allow us to have many handlers on
-# the same functor/arity
+# Update handler indirection -- a temporary hack. Allow us to have many handlers
+# on the same functor/arity. Eventually, we'll fuse handlers into one handler.
 
 def register(fn):
     """
@@ -238,9 +250,9 @@ def emit(item, val):
         % 'emit %s (val %s; curr: %s)' % (pretty(item), val, lookup(item))
 
     if _delete:
-        aggregator[item][val] -= 1
+        aggregator[item].dec(val)
     else:
-        aggregator[item][val] += 1
+        aggregator[item].inc(val)
 
     agenda.add(item)
 
@@ -255,11 +267,6 @@ def delete(item, val):
     _delete = False
 
 
-def aggregate(item):
-    (fn, _) = item
-    return agg[fn](item)
-
-
 def lookup(item):
     (fn, idx) = item
     return chart[fn].data[idx][-1]
@@ -275,7 +282,7 @@ def _go():
         print 'pop', pretty(item),
 
         was = lookup(item)
-        now = aggregate(item)
+        now = aggregator[item].fold()
 
         print 'was %s, now %s' % (was, now)
 
@@ -318,9 +325,6 @@ def load(f, verbose=True):
     # load generated code.
     execfile(f, globals())
 
-    # bind aggregators to definitions
-    agg_bind(agg, agg_decl, aggregator)
-
 
 def dump(code, filename='/tmp/tmp.dyna'):
     "Write code to file."
index 48673db639320a1ea0ad0769094df1d47f3e63b1..980afbff0d1d083133a9ddc2a757d2861b052e66 100644 (file)
@@ -3,13 +3,13 @@
 path(*start) min= 0 .
 path(B) min= path(A) + edge(A,B).
 
-start += "a".
+start := "a".
 
-edge("a","b") += 1 .
-edge("b","d") += 1 .
-edge("a","d") += 3 .
-edge("a","c") += 1 .
-edge("c","d") += 2 .
+edge("a","b") := 1.
+edge("b","d") := 1.
+edge("a","d") := 3.
+edge("a","c") := 1.
+edge("c","d") := 2.
 
 % Expected
 %  path("a") = 0
index 7f8a833a5ceeba094b2616a1dfcbc9a463147762..790583fdcac947f8768b6b6020cd4492e1c993e8 100644 (file)
@@ -1,13 +1,19 @@
 % Parsing a simple sentence.
 
 % CKY-like parsing
-phrase(X,I,K,t(X,TY)) += phrase(Y,I,K,TY) * rewrite(X,Y).
-phrase(X,I,K,t(X,TY,TZ)) += phrase(Y,I,J,TY) * phrase(Z,J,K,TZ) * rewrite(X,Y,Z).
+phrase(X,I,K,t(X,TY)) max= phrase(Y,I,K,TY) * rewrite(X,Y).
+phrase(X,I,K,t(X,TY,TZ)) max= phrase(Y,I,J,TY) * phrase(Z,J,K,TZ) * rewrite(X,Y,Z).
 
-goal(P) += phrase("S", 0, *length, P).
+goal(P) max= phrase("S", 0, *length, P).
+
+best max= pair(*phrase("S", 0, *length, P), P).
+
+bestScore max= pair(Score,_) is best, Score.
+bestParse max= pair(_,P) is best, P.
 
 length max= word(_, I), I+1.
 
+
 % grammar rules
 rewrite( "S",   "S",  ".") += 1.
 rewrite( "S",  "NP", "VP") += 1.
@@ -27,13 +33,13 @@ rewrite("Det",      "a") += 1.
 
 % sentence
 % "Papa at the caviar with the spoon ."
-word(  "Papa", 0).
-word(   "ate", 1).
-word(   "the", 2).
-word("caviar", 3).
-word(  "with", 4).
-word(     "a", 5).
-word( "spoon", 6).
-word(     ".", 7).
-
-phrase(W, I, I+1, W) += word(W, I), 1.
+word(  "Papa", 0) := true.
+word(   "ate", 1) := true.
+word(   "the", 2) := true.
+word("caviar", 3) := true.
+word(  "with", 4) := true.
+word(     "a", 5) := true.
+word( "spoon", 6) := true.
+word(     ".", 7) := true.
+
+phrase(W, I, I+1, W) max= word(W, I), 1.
index 695dac2f631f2e1de6f52a4a2b09d4973e8aff3e..cbe15b6e133a7e2be8b57b1ba199e42930856c5d 100644 (file)
@@ -206,6 +206,7 @@ dynaFunctorArgDispositions x = case x of
 dynaFunctorSelfDispositions :: (DFunct,Int) -> SelfDispos
 dynaFunctorSelfDispositions x = case x of
     ("pair",2)   -> SDQuote
+    ("eval",1)   -> SDEval
     (name, _) ->
        -- If it starts with a nonalpha, it prefers to evaluate
        let d = if C.isAlphaNum $ head $ BU.toString name
index d094a183f56079c6003dc0868c1a08e89698966a..aa225f9cb8842001354b137953d9cecb5b431860 100644 (file)
@@ -42,6 +42,7 @@ import qualified Text.Trifecta              as T
 constants = S.fromList
     [ ("+",2)
     , ("-",2)
+    , ("-",1)    -- unary negation
     , ("*",2)
     , ("/",2)
     , ("^",2)
@@ -53,7 +54,7 @@ constants = S.fromList
     , ("<=",2)
     , (">",2)
     , (">=",2)
-    , ("!",2)
+    , ("!",1)
     , ("mod",1)
     , ("abs",1)
     , ("log",1)
@@ -61,6 +62,7 @@ constants = S.fromList
     , ("and",2)
     , ("or",2)
     , ("not",1)
+    , ("eval",1)
     , ("true",0)
     , ("false",0)
     , ("null",0)   -- XXX is this right?
@@ -170,6 +172,7 @@ pycall table f vs = case (f, length vs) of
   ( "log", 1) -> call "log"
   ( "exp", 1) -> call "exp"
   (   "!", 1) -> call "not"
+  (   "-", 1) -> call "-"
 
   ( "null", 0) -> "None"
   ( "true", 0) -> "True"