]> hydra-www.ietfng.org Git - dyna2/commitdiff
implementation of @nwf's argm idea for backpointers (#29). Uses `$key` instead
authorTim Vieira <tim.f.vieira@gmail.com>
Fri, 5 Jul 2013 03:01:15 +0000 (23:01 -0400)
committerTim Vieira <tim.f.vieira@gmail.com>
Fri, 5 Jul 2013 03:01:15 +0000 (23:01 -0400)
of `argm`.

BUGFIX: hash and eq for list wasn't correct.

examples/dijkstra-backpointers.dyna
examples/ptb.dyna
src/Dyna/Backend/Python/aggregator.py [moved from src/Dyna/Backend/Python/defn.py with 93% similarity]
src/Dyna/Backend/Python/chart.py
src/Dyna/Backend/Python/interpreter.py
src/Dyna/Backend/Python/post/trace.py
src/Dyna/Backend/Python/term.py

index 8c950b6a399357ae13485e49f7302a2d8cfd80d9..1d899484fd917799617b6444970220f425fb9315 100644 (file)
@@ -1,7 +1,7 @@
 % single source shortest path with optimal path extraction.
-path(start) min= 0.
-path(B) min= path(A) + edge(A,B).
-goal min= path(end).
+path(start) argmin= [0, start].
+path(B) argmin= [path(A) + edge(A,B), A].
+goal argmin= [path(end), end].
 
 % expensive path
 edge("a","b") := 1.
@@ -16,12 +16,9 @@ edge("e","d") := 1.
 start := "a".
 end := "d".
 
-% use argmin to determine backpointers
-b(V) argmin= [path(U) + edge(U,V), U].
-
 % extract cheapest path by following backpointers from each vertex.
 bestpath(start) := [start].
-bestpath(V) := U is b(V), [V | bestpath(U)].
+bestpath(V) := U is $key(&path(V)), [V | bestpath(U)] for V != start.
 
 % the optimal path is the one from the `end`.
 optimalpath = reverse(bestpath(end)).
index 9e7990e61f0768afadef4dea62e1bb7786df1405..a05bd99314c756ffd9da5883136448288142ab83 100644 (file)
@@ -61,14 +61,12 @@ b([X,Y,Z|Xs]) := [X, b(Y), b([&'@'(X), Z|Xs])].
 b([X,Y,Z]) := [X,b(Y),b(Z)].
 
 % CKY parser
-phrase(X,I,K) max= phrase(Y,I,J) * phrase(Z,J,K) * p(X,Y,Z).
-phrase(X,I,K) max= phrase(Y,I,K) * p(X,Y).
-phrase(X,I,I+1) max= 1 for [I,X] in enumerate(sentence).
+phrase(X,I,K) argmax= [phrase(Y,I,J) * phrase(Z,J,K) * p(X,Y,Z), [[Y,I,J], [Z,J,K]]].
+phrase(X,I,K) argmax= [phrase(Y,I,K) * p(X,Y), [[Y,I,K]]].
+phrase(X,I,I+1) argmax= [1, X] for [I,X] in enumerate(sentence).
 
 % backpointers
-bk(X,I,K) argmax= [phrase(Y,I,J) * phrase(Z,J,K) * p(X,Y,Z), [[Y,I,J], [Z,J,K]]].
-bk(X,I,K) argmax= [phrase(Y,I,K) * p(X,Y), [[Y,I,K]]].
-bk(X,I,I+1) argmax= [1, X] for [I,X] in enumerate(sentence).
+bk(X,I,K) = $key(&phrase(X,I,K)).
 
 % extract path from backpointers
 path(X,I,K) := W is bk(X,I,K), W.
similarity index 93%
rename from src/Dyna/Backend/Python/defn.py
rename to src/Dyna/Backend/Python/aggregator.py
index 87fd23c48e1094ea3666def402f87b8f8b8504b8..39a6ba455d30ec6a89b3a5a94f0f6dea3e26b960 100644 (file)
@@ -39,6 +39,7 @@ class Aggregator(object):
     def clear(self):
         pass
 
+NoAggregator = Aggregator()
 
 class BAggregator(Counter, Aggregator):
 #    def __init__(self):
@@ -146,22 +147,25 @@ class min_equals(BAggregator):
         if len(s):
             return min(s)
 
-
-class argmax_equals(max_equals):
+class maxwithkey_equals(max_equals):
     def fold(self):
         m = max_equals.fold(self)
-        if m:
+        self.key = None
+        if m is not None:
             if not hasattr(m, 'aslist') or len(m.aslist) != 2:
                 raise AggregatorError("argmax expects a pair of values")
-            return m.aslist[1]
+            self.key = m.aslist[1]
+            return m.aslist[0]
 
-class argmin_equals(min_equals):
+class minwithkey_equals(min_equals):
     def fold(self):
         m = min_equals.fold(self)
-        if m:
+        self.key = None
+        if m is not None:
             if not hasattr(m, 'aslist') or len(m.aslist) != 2:
                 raise AggregatorError("argmin expects a pair of values")
-            return m.aslist[1]
+            self.key = m.aslist[1]
+            return m.aslist[0]
 
 
 class plus_equals(BAggregator):
@@ -228,11 +232,11 @@ defs = {
     'set=': set_equals,
     'bag=': bag_equals,
     'mean=': mean_equals,
-    'argmax=': argmax_equals,
-    'argmin=': argmin_equals,
+    'argmax=': maxwithkey_equals,
+    'argmin=': minwithkey_equals,
 }
 
-def aggregator(name):
+def aggregator(name, term):
     "Create aggregator by ``name``."
 
     if name is None:
index 951f6e3687baaddfce76a0386011505e18b2b398..cda1156e2375c1ca26dfeb674ba64610eec53cf0 100644 (file)
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from defn import aggregator
+from aggregator import aggregator
 from term import Term
 from utils import _repr
 
@@ -13,8 +13,8 @@ class Chart(object):
         self.ix = [defaultdict(set) for _ in xrange(arity)]
         self.agg_name = agg_name
 
-    def new_aggregator(self):
-        return aggregator(self.agg_name)
+    def new_aggregator(self, term):
+        return aggregator(self.agg_name, term)
 
     def __repr__(self):
         rows = [term for term in self.intern.values() if term.value is not None]
@@ -73,7 +73,7 @@ class Chart(object):
             return self.intern[args]
         except KeyError:
             self.intern[args] = term = Term(self.name, args)
-            term.aggregator = self.new_aggregator()
+            term.aggregator = self.new_aggregator(term)
             # index new term
             for i, x in enumerate(args):
                 self.ix[i][x].add(term)
index b81eb2890afb10e060df9866ca634fee5c9682e4..3d530927a4e5d8eb4da8ada3ac531920910d473a 100644 (file)
@@ -115,7 +115,6 @@ import load, post
 
 from term import Term, Cons, Nil
 from chart import Chart
-from defn import aggregator
 from utils import ip, red, green, blue, magenta, yellow, parse_attrs, \
     ddict, dynac, read_anf, strip_comments, _repr
 
@@ -235,10 +234,8 @@ class Interpreter(object):
         if nullary:
             print >> out
         for x in others:
-
             if x.startswith('$rule/'):
                 continue
-
             y = str(self.chart[x])   # skip empty chart
             if y:
                 print >> out, y
@@ -272,8 +269,7 @@ class Interpreter(object):
                     if i >= 5:
                         print >> out, '    %s more ...' % (len(I[r][etype]) - i)
                         break
-                    print >> out, '    when `%s` = %s' % (item, _repr(value))
-                    print >> out, '      %s' % (e)
+                    print >> out, '    `%s`: %s' % (item, e)
                 print >> out
 
         # errors pertaining to rules
@@ -403,6 +399,17 @@ class Interpreter(object):
             if was == now:
                 continue
 
+
+            # aggregator with special key
+            if hasattr(item.aggregator, 'key'):
+                key = self.build('$key/1', item)
+                if key.aggregator is None:
+                    from aggregator import aggregator
+                    key.aggregator = aggregator('=', key)
+                self.delete_emit(key, key.value, None, None)
+                self.emit(key, item.aggregator.key, None, None, delete=False)
+
+
             was_error = False
             if item in error:    # clear error
                 was_error = True
index e00797964713858852838692023a7b2220ee765c..5e8504feeb3fe427ecb9a56d0abfb96d261dc246 100644 (file)
@@ -7,7 +7,7 @@ TODO: have ANF output which functors are infix, prefix, nullary, etc.
 import re
 from collections import defaultdict
 
-import debug, defn
+import debug
 from draw_circuit import infer_edges
 from utils import yellow, green, cyan, red, _repr
 
@@ -134,7 +134,7 @@ class Crux(object):
     def format(self):
         rule = self.rule
         #src = rule.src.replace('\n',' ').strip()
-        #user_vars = dict(defn.user_vars(self.vs.items()))
+        #user_vars = dict(user_vars(self.vs.items()))
 
         graph = self.graph
         side = [self.get_function(x) for x in graph.outputs if x != rule.anf.result and x != rule.anf.head]
index 3d5e96df46ff010540bf9b005313b2de3c842a40..1c4a487eb338aed91c1b3a3adff09655482144dc 100644 (file)
@@ -1,6 +1,6 @@
 from errors import notimplemented
 from utils import _repr
-from defn import Aggregator
+from aggregator import NoAggregator
 
 
 # TODO: codegen should output a derived Term instance for each functor
@@ -22,14 +22,10 @@ class Term(object):
         return self.fn == other.fn and self.args == other.args
 
     def __cmp__(self, other):
-#        if self is other:
-#            return 0
         if other is None:
             return 1
         if not isinstance(other, Term):
             return 1
-#        if self == other:
-#            return 0
         return cmp((self.fn, self.args), (other.fn, other.args))
 
     def __repr__(self):
@@ -56,7 +52,7 @@ class Cons(Term):
         self.head = head
         self.tail = tail
         Term.__init__(self, 'cons/2', (head, tail))
-        self.aggregator = Aggregator()
+        self.aggregator = NoAggregator
         self.aslist = [self.head] + self.tail.aslist
 
     def __repr__(self):
@@ -72,6 +68,15 @@ class Cons(Term):
             else:
                 yield a, (None,), a
 
+    def __eq__(self, other):
+        try:
+            return self.aslist == other.aslist
+        except AttributeError:
+            return False
+
+    def __hash__(self):
+        return hash(tuple(self.aslist))
+
     def __cmp__(self, other):
         try:
             return cmp(self.aslist, other.aslist)
@@ -82,7 +87,7 @@ class Cons(Term):
 class _Nil(Term):
     def __init__(self):
         Term.__init__(self, 'nil/0', ())
-        self.aggregator = Aggregator()
+        self.aggregator = NoAggregator
         self.aslist = []
     def __repr__(self):
         return '[]'