]> hydra-www.ietfng.org Git - dyna2/commitdiff
small bugfixes in aggregators pertaining to when things should be null.
authortimv <tim.f.vieira@gmail.com>
Sun, 9 Jun 2013 19:42:56 +0000 (15:42 -0400)
committertimv <tim.f.vieira@gmail.com>
Sun, 9 Jun 2013 19:42:56 +0000 (15:42 -0400)
src/Dyna/Backend/Python/defn.py

index f8137b0ad72e66a0ae68e8b7eec8789c0108520a..0c7c5cb9bd8b51cc6de48f3e3df7790b44280def 100644 (file)
@@ -1,6 +1,8 @@
+# TODO: codegen should produce specialized Term with inc/dec methods baked
+# in. This seems nicer than having a separate aggregator object.
+
 import operator
 from collections import Counter
-from utils import red
 
 
 class Aggregator(object):
@@ -33,22 +35,36 @@ class BAggregator(Counter, Aggregator):
         assert False, "This method should never be called."
 
 
-class LastEquals(BAggregator):
+class PlusEquals(object):
+    __slots__ = 'pos', 'neg'
+    def __init__(self):
+        self.pos = 0
+        self.neg = 0
+    def inc(self, val, ruleix, variables):
+        self.pos += val
+    def dec(self, val, ruleix, variables):
+        self.neg += val
+    def fold(self):
+        return self.pos - self.neg
+
+
+class ColonEquals(BAggregator):
     def inc(self, val, ruleix, variables):
         self[ruleix, val] += 1
     def dec(self, val, ruleix, variables):
         self[ruleix, val] -= 1
     def fold(self):
-        return max(ruleix for ruleix, cnt in self.iteritems() if cnt > 0)[1]
+        vs = [v for v, cnt in self.iteritems() if cnt > 0]
+        if vs:
+            return max(vs)[1]
 
 
 def user_vars(variables):
     "Post process the variables past to emit (which passes them to aggregator)."
     # remove the 'u' prefix on user variables 'uX'
-
     # Note: We also ignore user variables with an underscore prefix
-
-    return tuple((name[1:], val) for name, val in variables.items() if name.startswith('u') and not name.startswith('u_'))
+    return tuple((name[1:], val) for name, val in variables.items()
+                 if name.startswith('u') and not name.startswith('u_'))
 
 
 class DictEquals(BAggregator):
@@ -64,7 +80,7 @@ class DictEquals(BAggregator):
         self[val, vs] -= 1
 
     def fold(self):
-        return list((x[0], dict(x[1])) for x, cnt in self.iteritems() if cnt > 0)
+        return list((x[0], dict(x[1])) for x, cnt in self.iteritems())
 
 
 def majority_equals(a):
@@ -144,7 +160,10 @@ def aggregator(name):
         return None
 
     if name == ':=':
-        return LastEquals(name, folder=None)
+        return ColonEquals(name, folder=None)
+
+#    elif name == '+=':
+#        return PlusEquals()
 
     elif name == 'dict=':
         return DictEquals(name, folder=None)