]> hydra-www.ietfng.org Git - dyna2/commitdiff
Fixed Issue #66 - "stale backchaining memos when new rules are added"
authorTim Vieira <tim.f.vieira@gmail.com>
Sun, 28 Jul 2013 18:33:37 +0000 (14:33 -0400)
committerTim Vieira <tim.f.vieira@gmail.com>
Sun, 28 Jul 2013 18:33:37 +0000 (14:33 -0400)
GBC computation issues proper replacement updates in the case were there was an
old value (i.e. item was non-null).

When rules are recompiled their index stays the same.

Rules which fail to compile are reported as errors, much like rules which fail
to initialize. Rule recompilation is now one rule at a time rather than all in
one go.

Added item hash function.

Updated doctests to reflect all of these changes.

src/Dyna/Backend/Python/interpreter.py
src/Dyna/Backend/Python/term.py
test/repl/late-bc.dynadoc
test/repl/recursion-limit.dynadoc
test/repl/retract-bc.dynadoc

index 6ea4adc95453629e443f37caa56ec436567da18c..8bafae8dbd95691bca365b98ccefcdfdcab0502f 100644 (file)
@@ -60,10 +60,20 @@ def none():
 
 class Interpreter(object):
 
-    @property
-    def parser_state(self):
+    def parser_state(self, ruleix=None):
         # TODO: this is pretty hacky. XREF:parser-state
-        bc, rix, agg, other = self.pstate
+        bc, _rix, agg, other = self.pstate
+
+        # ignore ruleix in pstate. We'll manage this ourselves.
+        if ruleix is None:
+            if not self.rules:
+                rix = 0
+            else:
+                rix = max(self.rules) + 1 # next available
+            rix = max(rix, _rix)
+        else:
+            rix = ruleix  # override
+
         lines = [':-ruleix %d.' % rix]
         for fn in bc:
             [(fn, arity)] = re.findall('(.*)/(\d+)', fn)
@@ -159,40 +169,6 @@ class Interpreter(object):
         self.agenda[item] = self.time_step
         self.time_step += 1
 
-    def replace(self, item, now):
-        "replace current value of ``item``, propagate any changes."
-        was = item.value
-        if was == now:
-            # nothing to do.
-            return
-
-        # special handling for with_key, forks a second update
-        k = self.build('$key/1', item)
-
-        if hasattr(now, 'fn') and now.fn == 'with_key/2':
-            now, key = now.args
-            self.replace(k, key)
-            if was == now:
-                return
-        else:
-            # retract $key when we retract the item or no longer have a with_key
-            # as the value.
-            if k.value is not None:
-                self.replace(k, None)
-
-        # delete existing value before so we can replace it
-        if was is not None:
-            self.push(item, was, delete=True)
-        # clear existing errors -- we only care about errors at new value.
-        self.clear_error(item)
-        # new value enters in the chart.
-        item.value = now
-        # push changes
-        if now is not None:
-            self.push(item, now, delete=False)
-        # make note of change
-        self.changed[item] = now
-
     def run_agenda(self):
         self.changed = {}
         try:
@@ -217,25 +193,51 @@ class Interpreter(object):
         Handle popping `item`: fold `item`'s aggregator to get it's new value
         (handle errors), propagate changes to the rest of the circuit.
         """
-
         if item.aggregator is None:
             return item
-
         try:
             # compute item's new value
             now = item.aggregator.fold()
-
         except (AggregatorError, ZeroDivisionError, ValueError, TypeError, OverflowError) as e:
             # handle error in aggregator
             now = Error()
             self.replace(item, now)
             self.set_error(item, (None, [(e, None)]))
-
         else:
             self.replace(item, now)
-
         return now
 
+    def replace(self, item, now):
+        "replace current value of ``item``, propagate any changes."
+        was = item.value
+        if was == now:
+            # nothing to do.
+            return
+        # special handling for with_key, forks a second update
+        k = self.build('$key/1', item)
+        if hasattr(now, 'fn') and now.fn == 'with_key/2':
+            now, key = now.args
+            self.replace(k, key)
+            if was == now:
+                return
+        else:
+            # retract $key when we retract the item or no longer have a with_key
+            # as the value.
+            if k.value is not None:
+                self.replace(k, None)
+        # delete existing value before so we can replace it
+        if was is not None:
+            self.push(item, was, delete=True)
+        # clear existing errors -- we only care about errors at new value.
+        self.clear_error(item)
+        # new value enters in the chart.
+        item.value = now
+        # push changes
+        if now is not None:
+            self.push(item, now, delete=False)
+        # make note of change
+        self.changed[item] = now
+
     def push(self, item, val, delete):
         """
         Passes update to relevant handlers. Catches errors.
@@ -281,9 +283,16 @@ class Interpreter(object):
         if item.value is not None:
             return item.value
 
+        return self.force_gbc(item)
+
+    def force_gbc(self, item):
+        "Skips memo on item check."
+
         if item.aggregator is None:   # we might not have a rule defining this subgoal.
             return
 
+        self.clear_error(item)
+
         item.aggregator.clear()
 
         emits = []
@@ -294,7 +303,7 @@ class Interpreter(object):
 
         for handler in self._gbc[item.fn]:
             try:
-                handler(*args, emit=t_emit)
+                handler(*item.args, emit=t_emit)
             except (ZeroDivisionError, ValueError, TypeError, RuntimeError, OverflowError) as e:
                 e.exception_frame = rule_error_context()
                 e.traceback = traceback.format_exc()
@@ -304,14 +313,12 @@ class Interpreter(object):
             self.set_error(item, (None, errors))
             return Error()
 
-        else:
-            for e in emits:
-                # an error could happen here, but we assume (by contract) that this
-                # is not possible.
-                self.emit(*e)
-            return self.pop(item)
+        for e in emits:
+            self.emit(*e)
+
+        return self.pop(item)
 
-    def load_plan(self, filename):
+    def load_plan(self, filename, recurse=True):
         """
         Compile, load, and execute new dyna rules.
 
@@ -348,49 +355,70 @@ class Interpreter(object):
             self.add_rule(index, query=h, head_fn=fn, anf=anf[index])
 
         for index, h in env.initializers:
-            assert index not in self.rules
             new_rules.add(index)
             self.add_rule(index, init=h, anf=anf[index])
 
         for fn, r, h in env.updaters:
             self.new_updater(fn, r, h)
 
-        # run to fixed point.
-        if self.recompile:
-            try:
-                plan = self.dynac_code('\n'.join(r.src for r in sorted(self.recompile, key=lambda r: r.index)))
-            except DynaCompilerError as e:
-                # TODO: it's a bit strange to ignore the error and just print
-                # it. However, since the rules in the recompile list are
-                # syntactically valid (well, they at least they were valid) --
-                # this means that errors must be planning errors... probably all
-                # to do with missing BC declarations.
-                #
-                # TODO: should probably at the very least report compiler errors
-                # in a similar fashion to initialization errors.
-                #
-                # TODO: we probably have to worry about infinite loops -- at the
-                # moment this results in an interpreter crash due to max
-                # recursion limit
-                print e
-            else:
-                # TODO: reuse old rule index when we recompile.
-                for r in self.recompile:
-                    self.retract_rule(r.index)
-                self.recompile.clear()
-                self.load_plan(plan)
-
-                # TODO: should probably indicate that some rules were recompiled
-                # and no longer in an error state. -- there is a bit of a
-                # mismatch with when we choose not to add a rule... in the try
-                # block above we reject rules on compiler error, but if the
-                # rules existed before and it now longer compiles we just print
-                # the error and add it to the recompile list.
+        if self.recompile and recurse:
+
+            # TODO: it's a bit strange to ignore the error and just print
+            # it. However, since the rules in the recompile list are
+            # syntactically valid (well, they at least they were valid) -- this
+            # means that errors must be planning errors... probably all to do
+            # with missing BC declarations.
+            #
+            # TODO: should probably at the very least report compiler errors in
+            # a similar fashion to initialization errors.
+            #
+            # TODO: we probably have to worry about infinite loops -- at the
+            # moment this results in an interpreter crash due to max recursion
+            # limit
+            #
+            # TODO: should probably indicate that some rules were recompiled and
+            # no longer in an error state. -- there is a bit of a mismatch with
+            # when we choose not to add a rule... in the try block above we
+            # reject rules on compiler error, but if the rules existed before
+            # and it now longer compiles we just print the error and add it to
+            # the recompile list.
+            #
+            # TODO: maybe we should handle recompilation more like we do
+            # uninitialized rules.
+
+            # run to fixed point.
+            while True:
+                failed = set()
+                for r in list(self.recompile):
+                    try:
+                        plan = self.recompile_rule(r)
+                    except DynaCompilerError as e:
+                        failed.add(r)
+                        self.set_error(r, e)
+                    else:
+                        self.retract_rule(r.index)
+                        self.load_plan(plan, recurse=False)
+                if failed == self.recompile:   # no progress
+                    break
+                self.recompile = failed
 
         # we we don't accumulate all changed rules, new_rules return will be new
         # rules of top-level call.
         return new_rules
 
+    def recompile_rule(self, r):
+        "returns a plan, it's up to you to retract the old rule and load the plan"
+        pstate = self.parser_state(ruleix=r.index)   # override ruleix
+        code = r.src
+        x = sha1()
+        x.update(pstate)
+        x.update(code)
+        dyna = self.tmp / ('%s.dyna' % x.hexdigest())
+        with file(dyna, 'wb') as f:
+            f.write(pstate)
+            f.write(code)
+        return self.dynac(dyna)
+
     def run_uninitialized(self):
         q = set(self.uninitialized_rules)
         failed = []
@@ -401,10 +429,8 @@ class Interpreter(object):
                 emits = []
                 def _emit(*args):
                     emits.append(args)
-
                 self.clear_error(rule)  # clear errors on rule, if any
                 rule.init(emit=_emit)
-
             except (ZeroDivisionError, ValueError, TypeError, RuntimeError, OverflowError) as e:
                 e.exception_frame = rule_error_context()
                 e.traceback = traceback.format_exc()
@@ -412,7 +438,6 @@ class Interpreter(object):
                 failed.append(rule)
             else:
                 rule.initialized = True
-                # process emits
                 for e in emits:
                     self.emit(*e, delete=False)
         self.uninitialized_rules = failed
@@ -448,16 +473,21 @@ class Interpreter(object):
             assert False, 'did not find head'
         assert head_fn is not None
 
+        # TODO: have backend send this information alongside the rule.
         span = hide_ugly_filename(parse_attrs(init or query)['Span'])
         dyna_src = strip_comments(parse_attrs(init or query)['rule'])
 
-        rule = self.rules[index] = Rule(index)
+        rule = Rule(index)
 
         rule.span = span
         rule.src = dyna_src
         rule.anf = anf
         rule.head_fn = head_fn
 
+        #assert index not in self.rules, {'new': rule, 'old': self.rules[index]}
+
+        self.rules[index] = rule
+
         self.update_coarse(rule)
 
         if init:
@@ -598,12 +628,7 @@ class Interpreter(object):
         visited.add(fn)
 
         for x in self.chart[fn].intern.values():
-            self.clear_error(x)
-            was = x.value                # remember old value we can do proper replacement update
-            x.value = None               # ignore memo
-            now = self.gbc(fn, *x.args)
-            x.value = was
-            self.replace(x, now)
+            self.force_gbc(x)
 
         # recompute dependent BC memos
         for v in self.coarse_deps[fn]:
@@ -636,12 +661,13 @@ class Interpreter(object):
 
     def dynac_code(self, code):
         "Compile a string of dyna code."
+        pstate = self.parser_state()
         x = sha1()
-        x.update(self.parser_state)
+        x.update(pstate)
         x.update(code)
         dyna = self.tmp / ('%s.dyna' % x.hexdigest())
         with file(dyna, 'wb') as f:
-            f.write(self.parser_state)  # include parser state if any.
+            f.write(pstate)  # include parser state if any.
             f.write(code)
         return self.dynac(dyna)
 
@@ -691,13 +717,13 @@ class Interpreter(object):
             (val, es) = x
             for e, h in es:
                 if h is None:
-                    I[item.fn][type(e)].append((item, val, e))
+                    I[item.fn][type(e)].append((item, e))
                 else:
                     assert h.rule.index in self.rules
                     E[h.rule][type(e)].append((item, val, e))
 
         # We only dump the error chart if it's non empty.
-        if not I and not E and not self.uninitialized_rules:
+        if not I and not E and not self.uninitialized_rules and not self.recompile:
             return
 
         print >> out
@@ -705,13 +731,13 @@ class Interpreter(object):
         print >> out, red % '======'
 
         # aggregation errors
-        for r in sorted(I, key=lambda r: r.index):
-            print >> out, 'Error(s) aggregating %s:' % r
-            for etype in I[r]:
+        for fn in sorted(I):
+            print >> out, 'Error(s) aggregating %s:' % fn
+            for etype in I[fn]:
                 print >> out, '  %s:' % etype.__name__
-                for i, (item, value, e) in enumerate(sorted(I[r][etype])):
+                for i, (item, e) in enumerate(sorted(I[fn][etype])):
                     if i >= 5:
-                        print >> out, '    %s more ...' % (len(I[r][etype]) - i)
+                        print >> out, '    %s more ...' % (len(I[fn][etype]) - i)
                         break
                     print >> out, '    `%s`: %s' % (item, e)
                 print >> out
@@ -732,8 +758,8 @@ class Interpreter(object):
                     print >> out, '    when `%s` = %s' % (item, _repr(value))
 
                     if 'maximum recursion depth exceeded' in str(e):
-                        # simplify recurision limit error because if prints some
-                        # unstable stuff.
+                        # simplify recurision limit error because it prints some
+                        # unpredictable stuff.
                         print >> out, '      maximum recursion depth exceeded'
                     else:
                         print >> out, '      %s' % (e)
@@ -746,7 +772,7 @@ class Interpreter(object):
         if self.uninitialized_rules:
             print >> out, red % 'Uninitialized rules'
             print >> out, red % '==================='
-            for rule in self.uninitialized_rules:
+            for rule in sorted(self.uninitialized_rules, key=lambda r: r.index):
                 e = self.error[rule]
                 print >> out, 'Failed to initialize rule:'
                 print >> out, '   ', rule.src
@@ -754,6 +780,19 @@ class Interpreter(object):
                 print >> out, rule.render_ctx(e.exception_frame, indent='    ')
                 print >> out
 
+        # rules which failed to recompile
+        if self.recompile:
+            print >> out, red % 'Failed to recompile'
+            print >> out, red % '==================='
+            for rule in sorted(self.recompile, key=lambda r: r.index):
+                e = self.error[rule]
+                print >> out, 'Failed to recompile rule:'
+                print >> out, '   ', rule.src
+                print >> out, '  with error'
+                for line in str(e).split('\n'):
+                    print >> out, '   ', line
+                print >> out
+
         print >> out
 
     def dump_rules(self):
index 38a3ea97ddf161c658610667210f47a5b453fb4a..b31d7f3cef2db3b3b739bc5befd193639290af54 100644 (file)
@@ -14,18 +14,19 @@ class Term(object):
         self.aggregator = None
 
     def __eq__(self, other):
-        if other is None:
-            return False
-        if not isinstance(other, Term):
+        try:
+            return (self.fn, self.args) == (other.fn, other.args)
+        except AttributeError:
             return False
-        return self.fn == other.fn and self.args == other.args
+
+    def __hash__(self):
+        return hash((self.fn, self.args))
 
     def __cmp__(self, other):
-        if other is None:
-            return 1
-        if not isinstance(other, Term):
+        try:
+            return cmp((self.fn, self.args), (other.fn, other.args))
+        except AttributeError:
             return 1
-        return cmp((self.fn, self.args), (other.fn, other.args))
 
     def __repr__(self):
         "Pretty print a term. Will retrieve the complete (ground) term."
@@ -34,24 +35,6 @@ class Term(object):
             return fn
         return '%s(%s)' % (fn, ','.join(map(_repr, self.args)))
 
-#    def __getstate__(self):
-#        return (self.fn, self.args, self.value, self.aggregator)
-
-#    def __setstate__(self, state):
-#        (self.fn, self.args, self.value, self.aggregator) = state
-
-#    def __add__(self, _):
-#        raise TypeError("Can't subtract terms.")
-
-#    def __sub__(self, _):
-#        raise TypeError("Can't add terms.")
-
-#    def __mul__(self, _):
-#        raise TypeError("Can't multiply terms.")
-
-#    def __div__(self, _):
-#        raise TypeError("Can't divide terms.")
-
 
 class Cons(Term):
 
@@ -83,14 +66,14 @@ class Cons(Term):
     def __iter__(self):
         return iter(self.aslist)
 
-    def __eq__(self, other):
-        try:
-            return self.aslist == other.aslist
-        except AttributeError:
-            return False
+#    def __eq__(self, other):
+#        try:
+#            return self.aslist == other.aslist
+#        except AttributeError:
+#            return False
 
-    def __hash__(self):
-        return hash(tuple(self.aslist))
+#    def __hash__(self):
+#        return hash(tuple(self.aslist))
 
 #    def __cmp__(self, other):
 #        try:
@@ -123,11 +106,11 @@ class _Nil(Term):
     def __iter__(self):
         return iter([])
 
-    def __eq__(self, other):
-        try:
-            return self.aslist == other.aslist
-        except AttributeError:
-            return False
+#    def __eq__(self, other):
+#        try:
+#            return self.aslist == other.aslist
+#        except AttributeError:
+#            return False
 
 #    def __cmp__(self, other):
 #        try:
index 19ceb6b686720f766b5178b025bc13083490b740..620cb75bc781a80a330e3da3c1b50cd200a59c58 100644 (file)
@@ -16,8 +16,8 @@ a(4) = 4.
 
 Rules
 =====
+  0: a(X) = f(X) for X in range(1,5).
   1: f(X) = X.
-  2: a(X) = f(X) for X in range(1,5).  % originally index 0, now 2.
 
 
 > foo(X) += bar(X).
@@ -26,14 +26,47 @@ Rules
 > :- backchain bar/1.
 > bar(X) = X+1.
 
-Encountered error in input program:
- Unable to plan initializers for rule(s):
-  foo(X) += bar(X).
-  foo(X) += 2*bar(X).
-Everything was syntactically valid, but we could not
-see it through.
+>>> 2 new errors. Type `sol` for details.
+
+> sol
+
+Solution
+========
+a/1
+===
+a(1) = 1.
+a(2) = 2.
+a(3) = 3.
+a(4) = 4.
+
+
+Errors
+======
+
+Failed to recompile
+===================
+Failed to recompile rule:
+    foo(X) += bar(X).
+  with error
+    Encountered error in input program:
+     Unable to plan initializers for rule(s):
+      foo(X) += bar(X).
+    Everything was syntactically valid, but we could not
+    see it through.
+
+Failed to recompile rule:
+    foo(X) += 2*bar(X).
+  with error
+    Encountered error in input program:
+     Unable to plan initializers for rule(s):
+      foo(X) += 2*bar(X).
+    Everything was syntactically valid, but we could not
+    see it through.
 
 > :- backchain foo/1.
+
+>>> 2 errors cleared.
+
 > query foo(3)
 
 foo(3) = 12.
index 0c7cb9455f0559c7499966e00913bf8112edf68d..68ca398593959ebc21005a493749c5847d50045c 100644 (file)
@@ -38,9 +38,9 @@ Errors
 Error(s) in rule 0: <repl>
     f(X) = f(X-1).
   RuntimeError:
-    when `f(-323)` = null
+    when `f(-241)` = null
       maximum recursion depth exceeded
-      f(X=-323) = f((X=-323 - 1)=-324)=?.
+      f(X=-241) = f((X=-241 - 1)=-242)=?.
 
 Uninitialized rules
 ===================
index 0f5ec6304e0f73e0be9d0436cbdfc5be49e15bd6..8b64f314c93e946b93187f71cab9dd78d46739d9 100644 (file)
@@ -54,7 +54,6 @@ a(1) = 1.
 > :- backchain f/1.
 | f(X) := f(X-1) * X for X > 1.
 | f(0) := 1.
-| f(0) := 5.
 | b(X) = f(X) for X in range(6).
 
 % Check that `a(x)` and `s`, know that we have a new definition! for `f`, just
@@ -62,27 +61,28 @@ a(1) = 1.
 
 Changes
 =======
+a(0) = 1.  % a(1) doesn't change.
 a(2) = 2.
 a(3) = 6.
 a(4) = 24.
 a(5) = 120.
-b(0) = 0.
+b(0) = 1.
 b(1) = 1.
 b(2) = 2.
 b(3) = 6.
 b(4) = 24.
 b(5) = 120.
-s = 153.
+s = 154.
 
 > sol
 
 Solution
 ========
-s = 153.
+s = 154.
 
 a/1
 ===
-a(0) = 0.
+a(0) = 1.
 a(1) = 1.
 a(2) = 2.
 a(3) = 6.
@@ -91,7 +91,7 @@ a(5) = 120.
 
 b/1
 ===
-b(0) = 0.
+b(0) = 1.
 b(1) = 1.
 b(2) = 2.
 b(3) = 6.