From ff2b3765521d2c4a100d80ac6fc6aae8c23d8458 Mon Sep 17 00:00:00 2001 From: timv Date: Sun, 9 Jun 2013 15:42:56 -0400 Subject: [PATCH] small bugfixes in aggregators pertaining to when things should be null. --- src/Dyna/Backend/Python/defn.py | 35 +++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/src/Dyna/Backend/Python/defn.py b/src/Dyna/Backend/Python/defn.py index f8137b0..0c7c5cb 100644 --- a/src/Dyna/Backend/Python/defn.py +++ b/src/Dyna/Backend/Python/defn.py @@ -1,6 +1,8 @@ +# TODO: codegen should produce specialized Term with inc/dec methods baked +# in. This seems nicer than having a separate aggregator object. + import operator from collections import Counter -from utils import red class Aggregator(object): @@ -33,22 +35,36 @@ class BAggregator(Counter, Aggregator): assert False, "This method should never be called." -class LastEquals(BAggregator): +class PlusEquals(object): + __slots__ = 'pos', 'neg' + def __init__(self): + self.pos = 0 + self.neg = 0 + def inc(self, val, ruleix, variables): + self.pos += val + def dec(self, val, ruleix, variables): + self.neg += val + def fold(self): + return self.pos - self.neg + + +class ColonEquals(BAggregator): def inc(self, val, ruleix, variables): self[ruleix, val] += 1 def dec(self, val, ruleix, variables): self[ruleix, val] -= 1 def fold(self): - return max(ruleix for ruleix, cnt in self.iteritems() if cnt > 0)[1] + vs = [v for v, cnt in self.iteritems() if cnt > 0] + if vs: + return max(vs)[1] def user_vars(variables): "Post process the variables past to emit (which passes them to aggregator)." # remove the 'u' prefix on user variables 'uX' - # Note: We also ignore user variables with an underscore prefix - - return tuple((name[1:], val) for name, val in variables.items() if name.startswith('u') and not name.startswith('u_')) + return tuple((name[1:], val) for name, val in variables.items() + if name.startswith('u') and not name.startswith('u_')) class DictEquals(BAggregator): @@ -64,7 +80,7 @@ class DictEquals(BAggregator): self[val, vs] -= 1 def fold(self): - return list((x[0], dict(x[1])) for x, cnt in self.iteritems() if cnt > 0) + return list((x[0], dict(x[1])) for x, cnt in self.iteritems()) def majority_equals(a): @@ -144,7 +160,10 @@ def aggregator(name): return None if name == ':=': - return LastEquals(name, folder=None) + return ColonEquals(name, folder=None) + +# elif name == '+=': +# return PlusEquals() elif name == 'dict=': return DictEquals(name, folder=None) -- 2.50.1