From 4a5b5095500a0d23f0d3def2d3f976acc59347fd Mon Sep 17 00:00:00 2001 From: timv Date: Wed, 12 Jun 2013 16:06:28 -0400 Subject: [PATCH] Progress toward pickling Interpreter state. Using import mechanism instead of execfile to load new code. Refactor aggregator defns. --- src/Dyna/Backend/Python/Backend.hs | 10 +- src/Dyna/Backend/Python/chart.py | 9 +- src/Dyna/Backend/Python/defn.py | 117 ++++++++++--------- src/Dyna/Backend/Python/interpreter.py | 148 +++++++++++++++++-------- src/Dyna/Backend/Python/repl.py | 12 +- src/Dyna/Backend/Python/term.py | 8 +- 6 files changed, 181 insertions(+), 123 deletions(-) diff --git a/src/Dyna/Backend/Python/Backend.hs b/src/Dyna/Backend/Python/Backend.hs index 7bbd6e1..4d08cd2 100644 --- a/src/Dyna/Backend/Python/Backend.hs +++ b/src/Dyna/Backend/Python/Backend.hs @@ -296,7 +296,7 @@ printInitializer fh rule cost dope = do `above` (indent 4 $ printPlanHeader rule cost Nothing) `above` pdope dope <> line - <> "_initializers.append((" <> (pretty $ r_index rule) <> ", _" <> "))" + <> "initializers.append((" <> (pretty $ r_index rule) <> ", _" <> "))" <> line <> line <> line @@ -313,7 +313,7 @@ printUpdate fh rule cost evalix (Just (f,a)) (hv,v) dope = do `above` (indent 4 $ printPlanHeader rule cost (Just evalix)) `above` pdope dope <> line - <> "_updaters.append((" <> (pfa f a) <> "," <> (pretty $ r_index rule) <> ",_))" + <> "updaters.append((" <> (pfa f a) <> "," <> (pretty $ r_index rule) <> ",_))" <> line <> line <> line @@ -333,9 +333,13 @@ driver am um {-qm-} is pr fh = do hPutStrLn fh "\"\"\"" hPutStrLn fh "" + hPutStrLn fh $ "agg_decl = {}" + hPutStrLn fh $ "updaters = []" + hPutStrLn fh $ "initializers = []" + -- Aggregation mapping forM_ (M.toList am) $ \((f,a),v) -> do - hPutStrLn fh $ show $ "_agg_decl" + hPutStrLn fh $ show $ "agg_decl" <> brackets (dquotes $ pretty f <> "/" <> pretty a) <+> equals <+> (dquotes $ pretty v) diff --git a/src/Dyna/Backend/Python/chart.py b/src/Dyna/Backend/Python/chart.py index b2fb50c..7cedf17 100644 --- a/src/Dyna/Backend/Python/chart.py +++ b/src/Dyna/Backend/Python/chart.py @@ -1,17 +1,20 @@ from collections import defaultdict from utils import notimplemented - +from defn import aggregator from term import Term, _repr class Chart(object): - def __init__(self, name, arity, new_aggregator): + def __init__(self, name, arity, agg_name): self.name = name self.arity = arity self.intern = {} # args -> term self.ix = [defaultdict(set) for _ in xrange(arity)] - self.new_aggregator = new_aggregator + self.agg_name = agg_name + + def new_aggregator(self): + return aggregator(self.agg_name) def __repr__(self): rows = [term for term in self.intern.values() if term.value is not None] diff --git a/src/Dyna/Backend/Python/defn.py b/src/Dyna/Backend/Python/defn.py index 0c7c5cb..8d39cc6 100644 --- a/src/Dyna/Backend/Python/defn.py +++ b/src/Dyna/Backend/Python/defn.py @@ -6,8 +6,6 @@ from collections import Counter class Aggregator(object): - def __init__(self, name): - self.name = name def fold(self): raise NotImplementedError def inc(self, val, ruleix, variables): @@ -16,24 +14,16 @@ class Aggregator(object): raise NotImplementedError def clear(self): raise NotImplementedError - def __repr__(self): - return 'Aggregator(%r)' % (self.name) class BAggregator(Counter, Aggregator): - def __init__(self, name, folder): - self.folder = folder - Aggregator.__init__(self, name) - Counter.__init__(self) - def fold(self): - return self.folder(self) def inc(self, val, ruleix, variables): self[val] += 1 def dec(self, val, ruleix, variables): self[val] -= 1 def fromkeys(self, *_): assert False, "This method should never be called." - + class PlusEquals(object): __slots__ = 'pos', 'neg' @@ -80,60 +70,71 @@ class DictEquals(BAggregator): self[val, vs] -= 1 def fold(self): - return list((x[0], dict(x[1])) for x, cnt in self.iteritems()) + return list((v, dict(b)) for (v, b), cnt in self.iteritems()) -def majority_equals(a): - [(k,_)] = a.most_common(1) - return k +class majority_equals(BAggregator): + def fold(self): + [(k,_)] = self.most_common(1) + return k -def max_equals(a): - s = [k for k, m in a.iteritems() if m > 0] - if len(s): - return max(s) +class max_equals(BAggregator): + def fold(self): + s = [k for k, m in self.iteritems() if m > 0] + if len(s): + return max(s) -def min_equals(a): - s = [k for k, m in a.iteritems() if m > 0] - if len(s): - return min(s) +class min_equals(BAggregator): + def fold(self): + s = [k for k, m in self.iteritems() if m > 0] + if len(s): + return min(s) -def plus_equals(a): - s = [k*m for k, m in a.iteritems() if m != 0] - if len(s): - return reduce(operator.add, s) +class plus_equals(BAggregator): + def fold(self): + s = [k*m for k, m in self.iteritems() if m != 0] + if len(s): + return reduce(operator.add, s) -def times_equals(a): - s = [k**m for k, m in a.iteritems() if m != 0] - if len(s): - return reduce(operator.mul, s) +class times_equals(BAggregator): + def fold(self): + s = [k**m for k, m in self.iteritems() if m != 0] + if len(s): + return reduce(operator.mul, s) -def and_equals(a): - s = [k for k, m in a.iteritems() if m > 0] - if len(s): - return reduce(lambda x,y: x and y, s) +class and_equals(BAggregator): + def fold(self): + s = [k for k, m in self.iteritems() if m > 0] + if len(s): + return reduce(lambda x,y: x and y, s) -def or_equals(a): - s = [k for k, m in a.iteritems() if m > 0] - if len(s): - return reduce(lambda x,y: x or y, s) +class or_equals(BAggregator): + def fold(self): + s = [k for k, m in self.iteritems() if m > 0] + if len(s): + return reduce(lambda x,y: x or y, s) -def b_and_equals(a): - s = [k for k, m in a.iteritems() if m > 0] - if len(s): - return reduce(operator.and_, s) +class b_and_equals(BAggregator): + def fold(self): + s = [k for k, m in self.iteritems() if m > 0] + if len(s): + return reduce(operator.and_, s) -def b_or_equals(a): - s = [k for k, m in a.iteritems() if m > 0] - if len(s): - return reduce(operator.or_, s) +class b_or_equals(BAggregator): + def fold(self): + s = [k for k, m in self.iteritems() if m > 0] + if len(s): + return reduce(operator.or_, s) -def set_equals(a): - s = {x for x, m in a.iteritems() if m > 0} - if len(s): - return s +class set_equals(BAggregator): + def fold(self): + s = {x for x, m in self.iteritems() if m > 0} + if len(s): + return s -def bag_equals(a): - return Counter(a) +class bag_equals(BAggregator): + def fold(self): + return Counter(self) # map names to functions @@ -152,7 +153,6 @@ defs = { 'bag=': bag_equals, } - def aggregator(name): "Create aggregator by ``name``." @@ -160,13 +160,10 @@ def aggregator(name): return None if name == ':=': - return ColonEquals(name, folder=None) - -# elif name == '+=': -# return PlusEquals() + return ColonEquals() elif name == 'dict=': - return DictEquals(name, folder=None) + return DictEquals() else: - return BAggregator(name, defs[name]) + return defs[name]() diff --git a/src/Dyna/Backend/Python/interpreter.py b/src/Dyna/Backend/Python/interpreter.py index fabd396..da10aeb 100644 --- a/src/Dyna/Backend/Python/interpreter.py +++ b/src/Dyna/Backend/Python/interpreter.py @@ -193,7 +193,7 @@ What is null? """ from __future__ import division -import os, sys +import os, sys, imp from collections import defaultdict from argparse import ArgumentParser @@ -204,7 +204,7 @@ from chart import Chart, Term, _repr from defn import aggregator from utils import ip, red, green, blue, magenta, yellow, \ notimplemented, parse_attrs, ddict, dynac, \ - DynaCompilerError, DynaInitializerException, AggregatorConflict + DynaCompilerError, DynaInitializerException from prioritydict import prioritydict from config import dotdynadir, dynahome @@ -235,6 +235,18 @@ class Rule(object): return 'Rule(%s, %r)' % (self.idx, self.src) +# TODO: yuck, hopefully temporary measure to support pickling the Interpreter's +# state +class foo(dict): + def __init__(self, agg_name): + self.agg_name = agg_name + super(foo, self).__init__() + def __missing__(self, fn): + arity = int(fn.split('/')[-1]) + self[fn] = c = Chart(fn, arity, self.agg_name[fn]) + return c + + class Interpreter(object): def __init__(self): @@ -246,20 +258,31 @@ class Interpreter(object): self.agenda = prioritydict() self.parser_state = '' - def newchart(fn): - arity = int(fn.split('/')[-1]) - return Chart(fn, arity, lambda: aggregator(self.agg_name[fn])) - - self.chart = ddict(newchart) + self.chart = foo(self.agg_name) self.rules = ddict(Rule) self.errors = {} +# def __getstate__(self): +# return ((self.chart, +# self.agenda, +# self.agenda, +# self.errors, +# self.agg_name, +# self.parser_state), +# '\n'.join(self.rules[i].src for i in sorted(self.rules))) + +# def __setstate__(self, state): +# ((self.chart, self.agenda, self.agenda, self.errors, self.agg_name, self.parser_state), code) = state +# self.edges = defaultdict(set) +# self.updaters = defaultdict(list) +# self.rules = ddict(Rule) +# self.do(self.dynac_code(code)) + def new_fn(self, fn, agg): # check for aggregator conflict. if fn not in self.agg_name: self.agg_name[fn] = agg - if self.agg_name[fn] != agg: - raise AggregatorConflict(fn, self.agg_name[fn], agg) + assert self.agg_name[fn] == agg, (fn, self.agg_name[fn], agg) def collect_edges(self): """ @@ -304,6 +327,27 @@ class Interpreter(object): def dump_rules(self): for i in sorted(self.rules): print '%3s: %s' % (i, self.rules[i].src) +# +# def query(self, q): +# if q.endswith('.'): +# print "Queries don't end with a dot." +# return +# +# query = 'out("%s") dict= %s.' % (q, q) +# +# src = self.dynac_code(query) # might raise DynaCompilerError +# self.do(src) +# +# try: +# [(_, _, results)] = self.chart['out/1'][q,:] +# except ValueError: +# print 'No results.' +# return +# +# for val, bindings in results: +# print ' ', val, 'when', bindings +# print + def build(self, fn, *args): # TODO: codegen should handle true/0 is True and false/0 is False @@ -320,28 +364,28 @@ class Interpreter(object): return self.chart[fn].insert(args) - def retract_item(self, item): - """ - For the moment we only correctly retract leaves. - - If you retract a non-leaf item, you run the risk of it being - rederived. In the case of cyclic programs the derivation might be the - same or different. - """ - # and now, for something truely horrendous -- look up an item by it's - # string value! This could fail because of whitespace or trivial - # formatting differences. - items = {} - for c in self.chart.values(): - for i in c.intern.values(): - items[str(i)] = i - try: - item = items[item] - except KeyError: - print 'item not found. This could be because of a trivial formatting differences...' - return - self.emit(item, item.value, None, sys.maxint, delete=True) - return self.go() +# def retract_item(self, item): +# """ +# For the moment we only correctly retract leaves. +# +# If you retract a non-leaf item, you run the risk of it being +# rederived. In the case of cyclic programs the derivation might be the +# same or different. +# """ +# # and now, for something truely horrendous -- look up an item by it's +# # string value! This could fail because of whitespace or trivial +# # formatting differences. +# items = {} +# for c in self.chart.values(): +# for i in c.intern.values(): +# items[str(i)] = i +# try: +# item = items[item] +# except KeyError: +# print 'item not found. This could be because of a trivial formatting differences...' +# return +# self.emit(item, item.value, None, sys.maxint, delete=True) +# return self.go() def retract_rule(self, idx): "Retract rule and all of it's edges." @@ -381,7 +425,7 @@ class Interpreter(object): was = item.value try: now = item.aggregator.fold() - except (ZeroDivisionError, TypeError, KeyboardInterrupt) as e: + except (ZeroDivisionError, TypeError, KeyboardInterrupt, NotImplementedError) as e: errors[item] = ('failed to aggregate %r' % item.aggregator, [(e, None)]) continue if was == now: @@ -404,7 +448,7 @@ class Interpreter(object): def update_dispatcher(self, item, val, delete): """ - Passes update to relevant handlers. + Passes update to relevant handlers. Catches errors. """ # store emissions, make sure all of them succeed before propagating @@ -454,7 +498,7 @@ class Interpreter(object): # self.agenda[item] = 0 # everything is high priority self.agenda[item] = time() # FIFO - def repl(self, hist): + def repl(self, hist = dotdynadir / 'dyna.hist'): import repl repl.REPL(self, hist).cmdloop() @@ -476,24 +520,28 @@ class Interpreter(object): # print >> self.trace, magenta % 'Loading new code' # print >> self.trace, yellow % h.read() - env = {'_initializers': [], '_updaters': [], '_agg_decl': {}, - 'chart': self.chart, 'build': self.build, 'peel': peel, - 'parser_state': None, 'uniform': uniform, - 'log': log, 'exp': exp, 'sqrt': sqrt} # load generated code. - execfile(filename, env) +# execfile(filename, env) + + env = imp.load_source('module.name', filename) + + for k,v in [('chart', self.chart), + ('build', self.build), + ('peel', peel), + ('uniform', uniform), ('log', log), ('exp', exp), ('sqrt', sqrt)]: + setattr(env, k, v) emits = [] def _emit(*args): emits.append(args) - for k, v in env['_agg_decl'].items(): + for k, v in env.agg_decl.items(): self.new_fn(k, v) try: # only run new initializers - for _, init in env['_initializers']: + for _, init in env.initializers: init(emit=_emit) except (TypeError, ZeroDivisionError) as e: @@ -505,13 +553,13 @@ class Interpreter(object): # in the middle of the following blocK? # add new updaters - for fn, r, h in env['_updaters']: + for fn, r, h in env.updaters: self.new_updater(fn, r, h) # add new initializers - for r, h in env['_initializers']: + for r, h in env.initializers: self.new_initializer(r, h) # accept the new parser state - self.parser_state = env['parser_state'] + self.parser_state = env.parser_state # process emits for e in emits: self.emit(*e, delete=False) @@ -586,7 +634,7 @@ def main(): if args.postprocess is not None: try: - pp =__import__(args.postprocess) + pp = __import__(args.postprocess) except ImportError: print ('ERROR: No postprocessor named %r' % args.postprocess) return @@ -656,16 +704,22 @@ def main(): interp.repl(hist = args.source + '.hist') else: - interp.repl(hist = '/tmp/dyna.hist') + interp.repl() if args.draw: interp.draw() +# interp.query('phrase(X,I,K)') + +# import cPickle +# out = cPickle.dumps(interp) # XXX: +# interp2 = cPickle.loads(out) # XXX: +# interp2.repl() + if args.postprocess is not None: # TODO: import and call main method instead. pp.main(interp) - if __name__ == '__main__': main() diff --git a/src/Dyna/Backend/Python/repl.py b/src/Dyna/Backend/Python/repl.py index 1de1e45..aac79fe 100644 --- a/src/Dyna/Backend/Python/repl.py +++ b/src/Dyna/Backend/Python/repl.py @@ -1,7 +1,7 @@ -import os, sys +import os import cmd, readline import interpreter -from utils import blue, yellow, green, magenta, ip, DynaCompilerError, AggregatorConflict, DynaInitializerException +from utils import blue, yellow, green, magenta, ip, DynaCompilerError, DynaInitializerException from chart import _repr from config import dotdynadir import debug @@ -30,8 +30,8 @@ class REPL(cmd.Cmd, object): def do_retract_rule(self, idx): self.interp.retract_rule(int(idx)) - def do_retract_item(self, item): - self.interp.retract_item(item) +# def do_retract_item(self, item): +# self.interp.retract_item(item) def do_exit(self, _): readline.write_history_file(self.hist) @@ -100,9 +100,9 @@ class REPL(cmd.Cmd, object): return try: src = self.interp.dynac_code(line) # might raise DynaCompilerError - changed = self.interp.do(src) # throws AggregatorConflict + changed = self.interp.do(src) - except (AggregatorConflict, DynaInitializerException, DynaCompilerError) as e: + except (DynaInitializerException, DynaCompilerError) as e: print type(e).__name__ + ':' print e print '> new rule(s) were not added to program.' diff --git a/src/Dyna/Backend/Python/term.py b/src/Dyna/Backend/Python/term.py index dd88c0c..441d143 100644 --- a/src/Dyna/Backend/Python/term.py +++ b/src/Dyna/Backend/Python/term.py @@ -29,11 +29,11 @@ class Term(object): return fn return '%s(%s)' % (fn, ','.join(map(_repr, self.args))) - def __getstate__(self): - return (self.fn, self.args, self.value, self.aggregator) +# def __getstate__(self): +# return (self.fn, self.args, self.value, self.aggregator) - def __setstate__(self, state): - (self.fn, self.args, self.value, self.aggregator) = state +# def __setstate__(self, state): +# (self.fn, self.args, self.value, self.aggregator) = state __add__ = __sub__ = __mul__ = notimplemented -- 2.50.1