From d9edbdfaa365665d1de5ae7caa857b3d27e513c8 Mon Sep 17 00:00:00 2001 From: Tim Vieira Date: Thu, 4 Jul 2013 23:01:15 -0400 Subject: [PATCH] implementation of @nwf's argm idea for backpointers (#29). Uses `$key` instead of `argm`. BUGFIX: hash and eq for list wasn't correct. --- examples/dijkstra-backpointers.dyna | 11 ++++----- examples/ptb.dyna | 10 ++++---- .../Backend/Python/{defn.py => aggregator.py} | 24 +++++++++++-------- src/Dyna/Backend/Python/chart.py | 8 +++---- src/Dyna/Backend/Python/interpreter.py | 17 +++++++++---- src/Dyna/Backend/Python/post/trace.py | 4 ++-- src/Dyna/Backend/Python/term.py | 19 +++++++++------ 7 files changed, 52 insertions(+), 41 deletions(-) rename src/Dyna/Backend/Python/{defn.py => aggregator.py} (93%) diff --git a/examples/dijkstra-backpointers.dyna b/examples/dijkstra-backpointers.dyna index 8c950b6..1d89948 100644 --- a/examples/dijkstra-backpointers.dyna +++ b/examples/dijkstra-backpointers.dyna @@ -1,7 +1,7 @@ % single source shortest path with optimal path extraction. -path(start) min= 0. -path(B) min= path(A) + edge(A,B). -goal min= path(end). +path(start) argmin= [0, start]. +path(B) argmin= [path(A) + edge(A,B), A]. +goal argmin= [path(end), end]. % expensive path edge("a","b") := 1. @@ -16,12 +16,9 @@ edge("e","d") := 1. start := "a". end := "d". -% use argmin to determine backpointers -b(V) argmin= [path(U) + edge(U,V), U]. - % extract cheapest path by following backpointers from each vertex. bestpath(start) := [start]. -bestpath(V) := U is b(V), [V | bestpath(U)]. +bestpath(V) := U is $key(&path(V)), [V | bestpath(U)] for V != start. % the optimal path is the one from the `end`. optimalpath = reverse(bestpath(end)). diff --git a/examples/ptb.dyna b/examples/ptb.dyna index 9e7990e..a05bd99 100644 --- a/examples/ptb.dyna +++ b/examples/ptb.dyna @@ -61,14 +61,12 @@ b([X,Y,Z|Xs]) := [X, b(Y), b([&'@'(X), Z|Xs])]. b([X,Y,Z]) := [X,b(Y),b(Z)]. % CKY parser -phrase(X,I,K) max= phrase(Y,I,J) * phrase(Z,J,K) * p(X,Y,Z). -phrase(X,I,K) max= phrase(Y,I,K) * p(X,Y). -phrase(X,I,I+1) max= 1 for [I,X] in enumerate(sentence). +phrase(X,I,K) argmax= [phrase(Y,I,J) * phrase(Z,J,K) * p(X,Y,Z), [[Y,I,J], [Z,J,K]]]. +phrase(X,I,K) argmax= [phrase(Y,I,K) * p(X,Y), [[Y,I,K]]]. +phrase(X,I,I+1) argmax= [1, X] for [I,X] in enumerate(sentence). % backpointers -bk(X,I,K) argmax= [phrase(Y,I,J) * phrase(Z,J,K) * p(X,Y,Z), [[Y,I,J], [Z,J,K]]]. -bk(X,I,K) argmax= [phrase(Y,I,K) * p(X,Y), [[Y,I,K]]]. -bk(X,I,I+1) argmax= [1, X] for [I,X] in enumerate(sentence). +bk(X,I,K) = $key(&phrase(X,I,K)). % extract path from backpointers path(X,I,K) := W is bk(X,I,K), W. diff --git a/src/Dyna/Backend/Python/defn.py b/src/Dyna/Backend/Python/aggregator.py similarity index 93% rename from src/Dyna/Backend/Python/defn.py rename to src/Dyna/Backend/Python/aggregator.py index 87fd23c..39a6ba4 100644 --- a/src/Dyna/Backend/Python/defn.py +++ b/src/Dyna/Backend/Python/aggregator.py @@ -39,6 +39,7 @@ class Aggregator(object): def clear(self): pass +NoAggregator = Aggregator() class BAggregator(Counter, Aggregator): # def __init__(self): @@ -146,22 +147,25 @@ class min_equals(BAggregator): if len(s): return min(s) - -class argmax_equals(max_equals): +class maxwithkey_equals(max_equals): def fold(self): m = max_equals.fold(self) - if m: + self.key = None + if m is not None: if not hasattr(m, 'aslist') or len(m.aslist) != 2: raise AggregatorError("argmax expects a pair of values") - return m.aslist[1] + self.key = m.aslist[1] + return m.aslist[0] -class argmin_equals(min_equals): +class minwithkey_equals(min_equals): def fold(self): m = min_equals.fold(self) - if m: + self.key = None + if m is not None: if not hasattr(m, 'aslist') or len(m.aslist) != 2: raise AggregatorError("argmin expects a pair of values") - return m.aslist[1] + self.key = m.aslist[1] + return m.aslist[0] class plus_equals(BAggregator): @@ -228,11 +232,11 @@ defs = { 'set=': set_equals, 'bag=': bag_equals, 'mean=': mean_equals, - 'argmax=': argmax_equals, - 'argmin=': argmin_equals, + 'argmax=': maxwithkey_equals, + 'argmin=': minwithkey_equals, } -def aggregator(name): +def aggregator(name, term): "Create aggregator by ``name``." if name is None: diff --git a/src/Dyna/Backend/Python/chart.py b/src/Dyna/Backend/Python/chart.py index 951f6e3..cda1156 100644 --- a/src/Dyna/Backend/Python/chart.py +++ b/src/Dyna/Backend/Python/chart.py @@ -1,5 +1,5 @@ from collections import defaultdict -from defn import aggregator +from aggregator import aggregator from term import Term from utils import _repr @@ -13,8 +13,8 @@ class Chart(object): self.ix = [defaultdict(set) for _ in xrange(arity)] self.agg_name = agg_name - def new_aggregator(self): - return aggregator(self.agg_name) + def new_aggregator(self, term): + return aggregator(self.agg_name, term) def __repr__(self): rows = [term for term in self.intern.values() if term.value is not None] @@ -73,7 +73,7 @@ class Chart(object): return self.intern[args] except KeyError: self.intern[args] = term = Term(self.name, args) - term.aggregator = self.new_aggregator() + term.aggregator = self.new_aggregator(term) # index new term for i, x in enumerate(args): self.ix[i][x].add(term) diff --git a/src/Dyna/Backend/Python/interpreter.py b/src/Dyna/Backend/Python/interpreter.py index b81eb28..3d53092 100644 --- a/src/Dyna/Backend/Python/interpreter.py +++ b/src/Dyna/Backend/Python/interpreter.py @@ -115,7 +115,6 @@ import load, post from term import Term, Cons, Nil from chart import Chart -from defn import aggregator from utils import ip, red, green, blue, magenta, yellow, parse_attrs, \ ddict, dynac, read_anf, strip_comments, _repr @@ -235,10 +234,8 @@ class Interpreter(object): if nullary: print >> out for x in others: - if x.startswith('$rule/'): continue - y = str(self.chart[x]) # skip empty chart if y: print >> out, y @@ -272,8 +269,7 @@ class Interpreter(object): if i >= 5: print >> out, ' %s more ...' % (len(I[r][etype]) - i) break - print >> out, ' when `%s` = %s' % (item, _repr(value)) - print >> out, ' %s' % (e) + print >> out, ' `%s`: %s' % (item, e) print >> out # errors pertaining to rules @@ -403,6 +399,17 @@ class Interpreter(object): if was == now: continue + + # aggregator with special key + if hasattr(item.aggregator, 'key'): + key = self.build('$key/1', item) + if key.aggregator is None: + from aggregator import aggregator + key.aggregator = aggregator('=', key) + self.delete_emit(key, key.value, None, None) + self.emit(key, item.aggregator.key, None, None, delete=False) + + was_error = False if item in error: # clear error was_error = True diff --git a/src/Dyna/Backend/Python/post/trace.py b/src/Dyna/Backend/Python/post/trace.py index e007979..5e8504f 100644 --- a/src/Dyna/Backend/Python/post/trace.py +++ b/src/Dyna/Backend/Python/post/trace.py @@ -7,7 +7,7 @@ TODO: have ANF output which functors are infix, prefix, nullary, etc. import re from collections import defaultdict -import debug, defn +import debug from draw_circuit import infer_edges from utils import yellow, green, cyan, red, _repr @@ -134,7 +134,7 @@ class Crux(object): def format(self): rule = self.rule #src = rule.src.replace('\n',' ').strip() - #user_vars = dict(defn.user_vars(self.vs.items())) + #user_vars = dict(user_vars(self.vs.items())) graph = self.graph side = [self.get_function(x) for x in graph.outputs if x != rule.anf.result and x != rule.anf.head] diff --git a/src/Dyna/Backend/Python/term.py b/src/Dyna/Backend/Python/term.py index 3d5e96d..1c4a487 100644 --- a/src/Dyna/Backend/Python/term.py +++ b/src/Dyna/Backend/Python/term.py @@ -1,6 +1,6 @@ from errors import notimplemented from utils import _repr -from defn import Aggregator +from aggregator import NoAggregator # TODO: codegen should output a derived Term instance for each functor @@ -22,14 +22,10 @@ class Term(object): return self.fn == other.fn and self.args == other.args def __cmp__(self, other): -# if self is other: -# return 0 if other is None: return 1 if not isinstance(other, Term): return 1 -# if self == other: -# return 0 return cmp((self.fn, self.args), (other.fn, other.args)) def __repr__(self): @@ -56,7 +52,7 @@ class Cons(Term): self.head = head self.tail = tail Term.__init__(self, 'cons/2', (head, tail)) - self.aggregator = Aggregator() + self.aggregator = NoAggregator self.aslist = [self.head] + self.tail.aslist def __repr__(self): @@ -72,6 +68,15 @@ class Cons(Term): else: yield a, (None,), a + def __eq__(self, other): + try: + return self.aslist == other.aslist + except AttributeError: + return False + + def __hash__(self): + return hash(tuple(self.aslist)) + def __cmp__(self, other): try: return cmp(self.aslist, other.aslist) @@ -82,7 +87,7 @@ class Cons(Term): class _Nil(Term): def __init__(self): Term.__init__(self, 'nil/0', ()) - self.aggregator = Aggregator() + self.aggregator = NoAggregator self.aslist = [] def __repr__(self): return '[]' -- 2.50.1