]> hydra-www.ietfng.org Git - dyna2/commitdiff
Refactored: Term's store reference to their aggregator.
authortimv <tim.f.vieira@gmail.com>
Mon, 3 Jun 2013 01:52:57 +0000 (21:52 -0400)
committertimv <tim.f.vieira@gmail.com>
Mon, 3 Jun 2013 01:52:57 +0000 (21:52 -0400)
examples/matrixops.dyna
src/Dyna/Backend/Python/defn.py
src/Dyna/Backend/Python/interpreter.py

index aadab380530dbbcc34a28177cf23ea9c88f0c3fd..62dbc0163971e2c94c94630a30ef417d55634869 100644 (file)
@@ -34,4 +34,4 @@ m(b, 2, 3) += 0 .
 
 % matrix "c" is the product of matricies "a" and "b"
 :-dispos product(&,&).
-product(a,b) += &c .
+product(a,b) := &c .
index 66c330defa24734b1cf5986a3c8396c5fc493763..f3267024f9c79d1b56d5deeabf59f2652ae020d0 100644 (file)
@@ -4,8 +4,7 @@ from utils import red
 
 
 class Aggregator(object):
-    def __init__(self, item, name):
-        self.item = item
+    def __init__(self, name):
         self.name = name
     def fold(self):
         raise NotImplementedError
@@ -16,12 +15,12 @@ class Aggregator(object):
     def clear(self):
         raise NotImplementedError
     def __repr__(self):
-        return 'Aggregator(%r, %r)' % (self.item, self.name)
+        return 'Aggregator(%r)' % (self.name)
 
 
 class BAggregator(Counter, Aggregator):
-    def __init__(self, item, name):
-        Aggregator.__init__(self, item, name)
+    def __init__(self, name):
+        Aggregator.__init__(self, name)
         Counter.__init__(self)
     def inc(self, val, ruleix, variables):
         self[val] += 1
@@ -34,9 +33,9 @@ class BAggregator(Counter, Aggregator):
 
 
 class MultisetAggregator(BAggregator):
-    def __init__(self, item, name, folder):
+    def __init__(self, name, folder):
         self.folder = folder
-        BAggregator.__init__(self, item, name)
+        BAggregator.__init__(self, name)
     def fold(self):
         return self.folder(self)
 
@@ -51,9 +50,9 @@ class LastEquals(BAggregator):
 
 
 class SetEquals(Aggregator):
-    def __init__(self, item, name):
+    def __init__(self, name):
         self.set = set([])
-        Aggregator.__init__(self, item, name)
+        Aggregator.__init__(self, name)
     def inc(self, val, ruleix, variables):
         self.set.add(val)
     def dec(self, val, ruleix, variables):
@@ -64,79 +63,81 @@ class SetEquals(Aggregator):
         self.set.clear()
 
 
-def agg_bind(item, agg_decl):
-    """
-    Bind declarations (map functor->string) to table (storing values) and
-    aggregator definition (the fold funciton, which gets executed).
-    """
-
-    def majority_equals(a):
-        [(k,_)] = a.most_common(1)
-        return k
-
-    def max_equals(a):
-        s = [k for k, m in a.iteritems() if m > 0]
-        if len(s):
-            return max(s)
-
-    def min_equals(a):
-        s = [k for k, m in a.iteritems() if m > 0]
-        if len(s):
-            return min(s)
-
-    def plus_equals(a):
-        s = [k*m for k, m in a.iteritems() if m != 0]
-        if len(s):
-            return reduce(operator.add, s)
-
-    def times_equals(a):
-        s = [k**m for k, m in a.iteritems() if m != 0]
-        if len(s):
-            return reduce(operator.mul, s)
-
-    def and_equals(a):
-        s = [k for k, m in a.iteritems() if m > 0]
-        if len(s):
-            return reduce(lambda x,y: x and y, s)
-
-    def or_equals(a):
-        s = [k for k, m in a.iteritems() if m > 0]
-        if len(s):
-            return reduce(lambda x,y: x or y, s)
-
-    def b_and_equals(a):
-        s = [k for k, m in a.iteritems() if m > 0]
-        if len(s):
-            return reduce(operator.and_, s)
-
-    def b_or_equals(a):
-        s = [k for k, m in a.iteritems() if m > 0]
-        if len(s):
-            return reduce(operator.or_, s)
-
-    # map names to functions
-    defs = {
-        'max=': max_equals,
-        'min=': min_equals,
-        '+=': plus_equals,
-        '*=': times_equals,
-        'and=': and_equals,
-        'or=': or_equals,
-        '&=': b_and_equals,
-        '|=': b_or_equals,
-        ':-': or_equals,
-        'majority=': majority_equals,
-    }
-
-
-    if agg_decl[item.fn] == ':=':
-        return LastEquals(item, agg_decl[item.fn])
-
-    elif agg_decl[item.fn] == 'bag=':
-        return BAggregator(item, agg_decl[item.fn])
-
-    elif agg_decl[item.fn] == 'set=':
-        return SetEquals(item, agg_decl[item.fn])
+
+
+def majority_equals(a):
+    [(k,_)] = a.most_common(1)
+    return k
+
+def max_equals(a):
+    s = [k for k, m in a.iteritems() if m > 0]
+    if len(s):
+        return max(s)
+
+def min_equals(a):
+    s = [k for k, m in a.iteritems() if m > 0]
+    if len(s):
+        return min(s)
+
+def plus_equals(a):
+    s = [k*m for k, m in a.iteritems() if m != 0]
+    if len(s):
+        return reduce(operator.add, s)
+
+def times_equals(a):
+    s = [k**m for k, m in a.iteritems() if m != 0]
+    if len(s):
+        return reduce(operator.mul, s)
+
+def and_equals(a):
+    s = [k for k, m in a.iteritems() if m > 0]
+    if len(s):
+        return reduce(lambda x,y: x and y, s)
+
+def or_equals(a):
+    s = [k for k, m in a.iteritems() if m > 0]
+    if len(s):
+        return reduce(lambda x,y: x or y, s)
+
+def b_and_equals(a):
+    s = [k for k, m in a.iteritems() if m > 0]
+    if len(s):
+        return reduce(operator.and_, s)
+
+def b_or_equals(a):
+    s = [k for k, m in a.iteritems() if m > 0]
+    if len(s):
+        return reduce(operator.or_, s)
+
+# map names to functions
+defs = {
+    'max=': max_equals,
+    'min=': min_equals,
+    '+=': plus_equals,
+    '*=': times_equals,
+    'and=': and_equals,
+    'or=': or_equals,
+    '&=': b_and_equals,
+    '|=': b_or_equals,
+    ':-': or_equals,
+    'majority=': majority_equals,
+}
+
+
+def aggregator(name):
+    "Create aggregator by ``name``."
+
+    if name is None:
+        return None
+
+    if name == ':=':
+        return LastEquals(name)
+
+    elif name == 'bag=':
+        return BAggregator(name)
+
+    elif name == 'set=':
+        return SetEquals(name)
 
     else:
-        return MultisetAggregator(item, agg_decl[item.fn], defs[agg_decl[item.fn]])
+        return MultisetAggregator(name, defs[name])
index b52b32a0b91d4f6bee15586cf124f381deed4029..50049a96e628cb6a383af301fbf1b1abf04059ac 100644 (file)
@@ -46,18 +46,6 @@ Warnings/lint checking
  - "initializers" aren't just initializers, they are the fully-naive bottom-up
    inference rules.
 
- - XXX: maybe the chart should store pretty printed term and a reference to the
-   aggregator (each item get's its own aggregator to avoid a hash lookup).
-
- - XXX: should we store value and aggregators separate from others columns? that
-   is, separate the chart and intern table.
-
-     timv: My new though on this is to store Term objects in the Chart. These
-     objects will contain a mutable reference to value and references to
-     aggregator and arguments.
-
-     TODO: need ot be more strict about interning Terms.
-
  - XXX: we should probably fuse update handlers instead of dispatching to each
    one independently.
 
@@ -136,7 +124,7 @@ from collections import defaultdict, namedtuple
 from argparse import ArgumentParser
 
 from utils import ip, red, green, blue, magenta, yellow, dynahome
-from defn import agg_bind
+from defn import aggregator
 
 
 class AggregatorConflict(Exception):
@@ -149,11 +137,6 @@ class chart_indirect(dict):
         c = self[key] = Chart(name = key, ncols = arity + 1)  # +1 for value
         return c
 
-class aggregator_indirect(dict):
-    def __missing__(self, item):
-        a = agg_bind(item, agg_decl)
-        self[item] = a
-        return a
 
 # when a new rule comes along it puts a string in the following dictionary
 class aggregator_declaration(object):
@@ -165,14 +148,16 @@ class aggregator_declaration(object):
                                      "set to %r." % (key, self.map[key], value))
         self.map[key] = value
     def __getitem__(self, key):
-        return self.map[key]
+        try:
+            return self.map[key]
+        except KeyError:
+            return None
 
 
 trace = None
 _delete = False
 agenda = set()
 agg_decl = aggregator_declaration()
-aggregator = aggregator_indirect()
 chart = chart_indirect()
 
 
@@ -189,20 +174,26 @@ def dump_charts(out=sys.stdout):
         print >> out
 
 
+def notimplemented(*_,**__):
+    raise NotImplementedError
+
+
 # TODO: codegen should output a derive Term instance for each functor
 class Term(namedtuple('Term', 'fn args'), object):
 
-    def __init__(self, fn, idx):
+    def __init__(self, fn, args):
         self._value = None
-        super(Term, self).__init__(fn, idx)
-
-    @property
-    def aggregator(self):
-        return aggregator[self]    # TODO: avoid this lookup
+        self.aggregator = None
+        super(Term, self).__init__(fn, args)
 
     def __repr__(self):
         return pretty(self)
 
+    __add__ \
+        = __sub__ \
+        = __mul__ \
+        = notimplemented
+
 #    @property
 #    def value(self):
 #        return self._value
@@ -212,14 +203,6 @@ class Term(namedtuple('Term', 'fn args'), object):
 #        assert not isinstance(val, tuple) or isinstance(val, Term)
 #        self._value = val
 
-# TODO: we don't story Term objects in the chart yet.. so we need to use the
-# namedtuple's __eq__ method.
-#
-# TODO: after interning we shouldn't need deep equality.
-#
-#    def __eq__(self, other):
-#        assert isinstance(other, Term), other
-#        return other.fn == self.fn and other.args == self.args
 
 def pretty(item):
     "Pretty print a term. Will retrieve the complete (ground) term from the chart."
@@ -304,12 +287,12 @@ class Chart(object):
 
         assert isinstance(args, tuple) and not isinstance(args, Term)
 
-
         # debugging check: row is not already in chart.
         assert self.lookup(args) is None, '%r already in chart with value %r' % (args, val)
 
         self.intern[args] = term = Term(self.name, args)
         term.value = val
+        term.aggregator = aggregator(agg_decl[self.name])
 
         for i, x in enumerate(args):
             self.ix[i][x].add(term)