hydra-www.ietfng.org Git - dyna2/commitdiff
Progress toward pickling Interpreter state. Using import mechanism instead of
author timv <tim.f.vieira@gmail.com>
Wed, 12 Jun 2013 20:06:28 +0000 (16:06 -0400)
committer timv <tim.f.vieira@gmail.com>
Wed, 12 Jun 2013 20:06:28 +0000 (16:06 -0400)
execfile to load new code.

Refactor aggregator defns.

src/Dyna/Backend/Python/Backend.hs
src/Dyna/Backend/Python/chart.py
src/Dyna/Backend/Python/defn.py
src/Dyna/Backend/Python/interpreter.py
src/Dyna/Backend/Python/repl.py
src/Dyna/Backend/Python/term.py

index 7bbd6e1c9a7c70faba8fd7071d2f8369a2d8cbad..4d08cd2a1ae632495c920ec2e50a273d28bd172d 100644 (file)
@@ -296,7 +296,7 @@ printInitializer fh rule cost dope = do
                    `above` (indent 4 $ printPlanHeader rule cost Nothing)
                    `above` pdope dope
                    <> line
-                   <> "_initializers.append((" <> (pretty $ r_index rule) <> ", _" <> "))"
+                   <> "initializers.append((" <> (pretty $ r_index rule) <> ", _" <> "))"
                    <> line
                    <> line
                    <> line
@@ -313,7 +313,7 @@ printUpdate fh rule cost evalix (Just (f,a)) (hv,v) dope = do
                    `above` (indent 4 $ printPlanHeader rule cost (Just evalix))
                    `above` pdope dope
                    <> line
-                   <> "_updaters.append((" <> (pfa f a) <> "," <> (pretty $ r_index rule) <> ",_))"
+                   <> "updaters.append((" <> (pfa f a) <> "," <> (pretty $ r_index rule) <> ",_))"
                    <> line
                    <> line
                    <> line
@@ -333,9 +333,13 @@ driver am um {-qm-} is pr fh = do
   hPutStrLn fh "\"\"\""
   hPutStrLn fh ""
 
+  hPutStrLn fh $ "agg_decl = {}"
+  hPutStrLn fh $ "updaters = []"
+  hPutStrLn fh $ "initializers = []"
+
   -- Aggregation mapping
   forM_ (M.toList am) $ \((f,a),v) -> do
-     hPutStrLn fh $ show $    "_agg_decl"
+     hPutStrLn fh $ show $    "agg_decl"
                            <> brackets (dquotes $ pretty f <> "/" <> pretty a)
                            <+> equals <+> (dquotes $ pretty v)
 
index b2fb50cd574571554c5e8560076c091f83f68f6f..7cedf179cad1cb1f71f0a7d33b74c8030a3ec136 100644 (file)
@@ -1,17 +1,20 @@
 from collections import defaultdict
 from utils import notimplemented
-
+from defn import aggregator
 from term import Term, _repr
 
 
 class Chart(object):
 
-    def __init__(self, name, arity, new_aggregator):
+    def __init__(self, name, arity, agg_name):
         self.name = name
         self.arity = arity
         self.intern = {}   # args -> term
         self.ix = [defaultdict(set) for _ in xrange(arity)]
-        self.new_aggregator = new_aggregator
+        self.agg_name = agg_name
+
+    def new_aggregator(self):
+        return aggregator(self.agg_name)
 
     def __repr__(self):
         rows = [term for term in self.intern.values() if term.value is not None]
index 0c7c5cb9bd8b51cc6de48f3e3df7790b44280def..8d39cc6178175119347d6bcd40b7d834465f6604 100644 (file)
@@ -6,8 +6,6 @@ from collections import Counter
 
 
 class Aggregator(object):
-    def __init__(self, name):
-        self.name = name
     def fold(self):
         raise NotImplementedError
     def inc(self, val, ruleix, variables):
@@ -16,24 +14,16 @@ class Aggregator(object):
         raise NotImplementedError
     def clear(self):
         raise NotImplementedError
-    def __repr__(self):
-        return 'Aggregator(%r)' % (self.name)
 
 
 class BAggregator(Counter, Aggregator):
-    def __init__(self, name, folder):
-        self.folder = folder
-        Aggregator.__init__(self, name)
-        Counter.__init__(self)
-    def fold(self):
-        return self.folder(self)
     def inc(self, val, ruleix, variables):
         self[val] += 1
     def dec(self, val, ruleix, variables):
         self[val] -= 1
     def fromkeys(self, *_):
         assert False, "This method should never be called."
-
+        
 
 class PlusEquals(object):
     __slots__ = 'pos', 'neg'
@@ -80,60 +70,71 @@ class DictEquals(BAggregator):
         self[val, vs] -= 1
 
     def fold(self):
-        return list((x[0], dict(x[1])) for x, cnt in self.iteritems())
+        return list((v, dict(b)) for (v, b), cnt in self.iteritems())
 
 
-def majority_equals(a):
-    [(k,_)] = a.most_common(1)
-    return k
+class majority_equals(BAggregator):
+    def fold(self):
+        [(k,_)] = self.most_common(1)
+        return k
 
-def max_equals(a):
-    s = [k for k, m in a.iteritems() if m > 0]
-    if len(s):
-        return max(s)
+class max_equals(BAggregator):
+    def fold(self):
+        s = [k for k, m in self.iteritems() if m > 0]
+        if len(s):
+            return max(s)
 
-def min_equals(a):
-    s = [k for k, m in a.iteritems() if m > 0]
-    if len(s):
-        return min(s)
+class min_equals(BAggregator):
+    def fold(self):
+        s = [k for k, m in self.iteritems() if m > 0]
+        if len(s):
+            return min(s)
 
-def plus_equals(a):
-    s = [k*m for k, m in a.iteritems() if m != 0]
-    if len(s):
-        return reduce(operator.add, s)
+class plus_equals(BAggregator):
+    def fold(self):
+        s = [k*m for k, m in self.iteritems() if m != 0]
+        if len(s):
+            return reduce(operator.add, s)
 
-def times_equals(a):
-    s = [k**m for k, m in a.iteritems() if m != 0]
-    if len(s):
-        return reduce(operator.mul, s)
+class times_equals(BAggregator):
+    def fold(self):
+        s = [k**m for k, m in self.iteritems() if m != 0]
+        if len(s):
+            return reduce(operator.mul, s)
 
-def and_equals(a):
-    s = [k for k, m in a.iteritems() if m > 0]
-    if len(s):
-        return reduce(lambda x,y: x and y, s)
+class and_equals(BAggregator):
+    def fold(self):
+        s = [k for k, m in self.iteritems() if m > 0]
+        if len(s):
+            return reduce(lambda x,y: x and y, s)
 
-def or_equals(a):
-    s = [k for k, m in a.iteritems() if m > 0]
-    if len(s):
-        return reduce(lambda x,y: x or y, s)
+class or_equals(BAggregator):
+    def fold(self):
+        s = [k for k, m in self.iteritems() if m > 0]
+        if len(s):
+            return reduce(lambda x,y: x or y, s)
 
-def b_and_equals(a):
-    s = [k for k, m in a.iteritems() if m > 0]
-    if len(s):
-        return reduce(operator.and_, s)
+class b_and_equals(BAggregator):
+    def fold(self):
+        s = [k for k, m in self.iteritems() if m > 0]
+        if len(s):
+            return reduce(operator.and_, s)
 
-def b_or_equals(a):
-    s = [k for k, m in a.iteritems() if m > 0]
-    if len(s):
-        return reduce(operator.or_, s)
+class b_or_equals(BAggregator):
+    def fold(self):
+        s = [k for k, m in self.iteritems() if m > 0]
+        if len(s):
+            return reduce(operator.or_, s)
 
-def set_equals(a):
-    s = {x for x, m in a.iteritems() if m > 0}
-    if len(s):
-        return s
+class set_equals(BAggregator):
+    def fold(self):
+        s = {x for x, m in self.iteritems() if m > 0}
+        if len(s):
+            return s
 
-def bag_equals(a):
-    return Counter(a)
+class bag_equals(BAggregator):
+    def fold(self):
+        return Counter(self)
 
 
 # map names to functions
@@ -152,7 +153,6 @@ defs = {
     'bag=': bag_equals,
 }
 
-
 def aggregator(name):
     "Create aggregator by ``name``."
 
@@ -160,13 +160,10 @@ def aggregator(name):
         return None
 
     if name == ':=':
-        return ColonEquals(name, folder=None)
-
-#    elif name == '+=':
-#        return PlusEquals()
+        return ColonEquals()
 
     elif name == 'dict=':
-        return DictEquals(name, folder=None)
+        return DictEquals()
 
     else:
-        return BAggregator(name, defs[name])
+        return defs[name]()
index fabd396c7c0dc4cfc0f570194817a5fc79d98644..da10aeb0a60726adc134c5f41c7a184352f52521 100644 (file)
@@ -193,7 +193,7 @@ What is null?
 """
 
 from __future__ import division
-import os, sys
+import os, sys, imp
 from collections import defaultdict
 from argparse import ArgumentParser
 
@@ -204,7 +204,7 @@ from chart import Chart, Term, _repr
 from defn import aggregator
 from utils import ip, red, green, blue, magenta, yellow, \
     notimplemented, parse_attrs, ddict, dynac, \
-    DynaCompilerError, DynaInitializerException, AggregatorConflict
+    DynaCompilerError, DynaInitializerException
 from prioritydict import prioritydict
 from config import dotdynadir, dynahome
 
@@ -235,6 +235,18 @@ class Rule(object):
         return 'Rule(%s, %r)' % (self.idx, self.src)
 
 
+# TODO: yuck, hopefully temporary measure to support pickling the Interpreter's
+# state
+class foo(dict):
+    def __init__(self, agg_name):
+        self.agg_name = agg_name
+        super(foo, self).__init__()
+    def __missing__(self, fn):
+        arity = int(fn.split('/')[-1])
+        self[fn] = c = Chart(fn, arity, self.agg_name[fn])
+        return c
+
+
 class Interpreter(object):
 
     def __init__(self):
@@ -246,20 +258,31 @@ class Interpreter(object):
         self.agenda = prioritydict()
         self.parser_state = ''
 
-        def newchart(fn):
-            arity = int(fn.split('/')[-1])
-            return Chart(fn, arity, lambda: aggregator(self.agg_name[fn]))
-
-        self.chart = ddict(newchart)
+        self.chart = foo(self.agg_name)
         self.rules = ddict(Rule)
         self.errors = {}
 
+#    def __getstate__(self):
+#        return ((self.chart,
+#                 self.agenda,
+#                 self.agenda,
+#                 self.errors,
+#                 self.agg_name,
+#                 self.parser_state),
+#                '\n'.join(self.rules[i].src for i in sorted(self.rules)))
+
+#    def __setstate__(self, state):
+#        ((self.chart, self.agenda, self.agenda, self.errors, self.agg_name, self.parser_state), code) = state
+#        self.edges = defaultdict(set)
+#        self.updaters = defaultdict(list)
+#        self.rules = ddict(Rule)
+#        self.do(self.dynac_code(code))
+
     def new_fn(self, fn, agg):
         # check for aggregator conflict.
         if fn not in self.agg_name:
             self.agg_name[fn] = agg
-        if self.agg_name[fn] != agg:
-            raise AggregatorConflict(fn, self.agg_name[fn], agg)
+        assert self.agg_name[fn] == agg, (fn, self.agg_name[fn], agg)
 
     def collect_edges(self):
         """
@@ -304,6 +327,27 @@ class Interpreter(object):
     def dump_rules(self):
         for i in sorted(self.rules):
             print '%3s: %s' % (i, self.rules[i].src)
+#
+#    def query(self, q):
+#        if q.endswith('.'):
+#            print "Queries don't end with a dot."
+#            return
+#
+#        query = 'out("%s") dict= %s.' % (q, q)
+#
+#        src = self.dynac_code(query)   # might raise DynaCompilerError
+#        self.do(src)
+#
+#        try:
+#            [(_, _, results)] = self.chart['out/1'][q,:]
+#        except ValueError:
+#            print 'No results.'
+#            return
+#
+#        for val, bindings in results:
+#            print '   ', val, 'when', bindings
+#        print
+
 
     def build(self, fn, *args):
         # TODO: codegen should handle true/0 is True and false/0 is False
@@ -320,28 +364,28 @@ class Interpreter(object):
 
         return self.chart[fn].insert(args)
 
-    def retract_item(self, item):
-        """
-        For the moment we only correctly retract leaves.
-
-        If you retract a non-leaf item, you run the risk of it being
-        rederived. In the case of cyclic programs the derivation might be the
-        same or different.
-        """
-        # and now, for something truely horrendous -- look up an item by it's
-        # string value! This could fail because of whitespace or trivial
-        # formatting differences.
-        items = {}
-        for c in self.chart.values():
-            for i in c.intern.values():
-                items[str(i)] = i
-        try:
-            item = items[item]
-        except KeyError:
-            print 'item not found. This could be because of a trivial formatting differences...'
-            return
-        self.emit(item, item.value, None, sys.maxint, delete=True)
-        return self.go()
+#    def retract_item(self, item):
+#        """
+#        For the moment we only correctly retract leaves.
+#
+#        If you retract a non-leaf item, you run the risk of it being
+#        rederived. In the case of cyclic programs the derivation might be the
+#        same or different.
+#        """
+#        # and now, for something truely horrendous -- look up an item by it's
+#        # string value! This could fail because of whitespace or trivial
+#        # formatting differences.
+#        items = {}
+#        for c in self.chart.values():
+#            for i in c.intern.values():
+#                items[str(i)] = i
+#        try:
+#            item = items[item]
+#        except KeyError:
+#            print 'item not found. This could be because of a trivial formatting differences...'
+#            return
+#        self.emit(item, item.value, None, sys.maxint, delete=True)
+#        return self.go()
 
     def retract_rule(self, idx):
         "Retract rule and all of it's edges."
@@ -381,7 +425,7 @@ class Interpreter(object):
             was = item.value
             try:
                 now = item.aggregator.fold()
-            except (ZeroDivisionError, TypeError, KeyboardInterrupt) as e:
+            except (ZeroDivisionError, TypeError, KeyboardInterrupt, NotImplementedError) as e:
                 errors[item] = ('failed to aggregate %r' % item.aggregator, [(e, None)])
                 continue
             if was == now:
@@ -404,7 +448,7 @@ class Interpreter(object):
 
     def update_dispatcher(self, item, val, delete):
         """
-        Passes update to relevant handlers.
+        Passes update to relevant handlers. Catches errors.
         """
 
         # store emissions, make sure all of them succeed before propagating
@@ -454,7 +498,7 @@ class Interpreter(object):
 #        self.agenda[item] = 0   # everything is high priority
         self.agenda[item] = time()  # FIFO
 
-    def repl(self, hist):
+    def repl(self, hist = dotdynadir / 'dyna.hist'):
         import repl
         repl.REPL(self, hist).cmdloop()
 
@@ -476,24 +520,28 @@ class Interpreter(object):
 #            print >> self.trace, magenta % 'Loading new code'
 #            print >> self.trace, yellow % h.read()
 
-        env = {'_initializers': [], '_updaters': [], '_agg_decl': {},
-               'chart': self.chart, 'build': self.build, 'peel': peel,
-               'parser_state': None, 'uniform': uniform,
-               'log': log, 'exp': exp, 'sqrt': sqrt}
 
         # load generated code.
-        execfile(filename, env)
+#        execfile(filename, env)
+
+        env = imp.load_source('module.name', filename)
+
+        for k,v in [('chart', self.chart),
+                    ('build', self.build),
+                    ('peel', peel),
+                    ('uniform', uniform), ('log', log), ('exp', exp), ('sqrt', sqrt)]:
+            setattr(env, k, v)
 
         emits = []
         def _emit(*args):
             emits.append(args)
 
-        for k, v in env['_agg_decl'].items():
+        for k, v in env.agg_decl.items():
             self.new_fn(k, v)
 
         try:
             # only run new initializers
-            for _, init in env['_initializers']:
+            for _, init in env.initializers:
                 init(emit=_emit)
 
         except (TypeError, ZeroDivisionError) as e:
@@ -505,13 +553,13 @@ class Interpreter(object):
             # in the middle of the following blocK?
 
             # add new updaters
-            for fn, r, h in env['_updaters']:
+            for fn, r, h in env.updaters:
                 self.new_updater(fn, r, h)
             # add new initializers
-            for r, h in env['_initializers']:
+            for r, h in env.initializers:
                 self.new_initializer(r, h)
             # accept the new parser state
-            self.parser_state = env['parser_state']
+            self.parser_state = env.parser_state
             # process emits
             for e in emits:
                 self.emit(*e, delete=False)
@@ -586,7 +634,7 @@ def main():
 
     if args.postprocess is not None:
         try:
-            pp =__import__(args.postprocess)
+            pp = __import__(args.postprocess)
         except ImportError:
             print ('ERROR: No postprocessor named %r' % args.postprocess)
             return
@@ -656,16 +704,22 @@ def main():
             interp.repl(hist = args.source + '.hist')
 
     else:
-        interp.repl(hist = '/tmp/dyna.hist')
+        interp.repl()
 
     if args.draw:
         interp.draw()
 
+#    interp.query('phrase(X,I,K)')
+
+#    import cPickle
+#    out = cPickle.dumps(interp)  # XXX:
+#    interp2 = cPickle.loads(out)  # XXX:
+#    interp2.repl()
+
     if args.postprocess is not None:
         # TODO: import and call main method instead.
         pp.main(interp)
 
 
-
 if __name__ == '__main__':
     main()
index 1de1e455fdc0f17a1088dd564e75e81fd301e744..aac79fe02ef6d95dcb554029a04969a2dcaa1db2 100644 (file)
@@ -1,7 +1,7 @@
-import os, sys
+import os
 import cmd, readline
 import interpreter
-from utils import blue, yellow, green, magenta, ip, DynaCompilerError, AggregatorConflict, DynaInitializerException
+from utils import blue, yellow, green, magenta, ip, DynaCompilerError, DynaInitializerException
 from chart import _repr
 from config import dotdynadir
 import debug
@@ -30,8 +30,8 @@ class REPL(cmd.Cmd, object):
     def do_retract_rule(self, idx):
         self.interp.retract_rule(int(idx))
 
-    def do_retract_item(self, item):
-        self.interp.retract_item(item)
+#    def do_retract_item(self, item):
+#       self.interp.retract_item(item)
 
     def do_exit(self, _):
         readline.write_history_file(self.hist)
@@ -100,9 +100,9 @@ class REPL(cmd.Cmd, object):
             return
         try:
             src = self.interp.dynac_code(line)   # might raise DynaCompilerError
-            changed = self.interp.do(src)        # throws AggregatorConflict
+            changed = self.interp.do(src)
 
-        except (AggregatorConflict, DynaInitializerException, DynaCompilerError) as e:
+        except (DynaInitializerException, DynaCompilerError) as e:
             print type(e).__name__ + ':'
             print e
             print '> new rule(s) were not added to program.'
index dd88c0c695651284f085916824544927e4b5980c..441d14352b4ecfd5515595107115ec0afb7f52f0 100644 (file)
@@ -29,11 +29,11 @@ class Term(object):
             return fn
         return '%s(%s)' % (fn, ','.join(map(_repr, self.args)))
 
-    def __getstate__(self):
-        return (self.fn, self.args, self.value, self.aggregator)
+#    def __getstate__(self):
+#        return (self.fn, self.args, self.value, self.aggregator)
 
-    def __setstate__(self, state):
-        (self.fn, self.args, self.value, self.aggregator) = state
+#    def __setstate__(self, state):
+#        (self.fn, self.args, self.value, self.aggregator) = state
 
     __add__ = __sub__ = __mul__ = notimplemented