]> hydra-www.ietfng.org Git - dyna2/commitdiff
Proposal for definition of `argmax=/argmin=`.
authorTim Vieira <tim.f.vieira@gmail.com>
Tue, 2 Jul 2013 21:46:07 +0000 (17:46 -0400)
committerTim Vieira <tim.f.vieira@gmail.com>
Tue, 2 Jul 2013 21:46:07 +0000 (17:46 -0400)
src/Dyna/Backend/Python/Backend.hs
src/Dyna/Backend/Python/defn.py
src/Dyna/Backend/Python/interpreter.py
src/Dyna/Backend/Python/post/trace.py
src/Dyna/Backend/Python/stdlib.py

index dd0be5688d87515406a8412a4e30e11cc01b8978..2de761739f2b17c64085056be421a752f241f5c4 100644 (file)
@@ -49,7 +49,7 @@ import           Text.PrettyPrint.Free
 
 aggrs :: S.Set String
 aggrs = S.fromList
-  [ "max=" , "min="
+  [ "max=" , "min=", "argmax=", "argmin="
   , "+=" , "*="
   , "and=" , "or=" , "&=" , "|="
   , ":-"
index 6a6bd893d0f4f0cdd51fbb7a791238b9f4aa1161..a0a33563e37175af524028af451701a81057e287 100644 (file)
@@ -144,6 +144,24 @@ class min_equals(BAggregator):
         if len(s):
             return min(s)
 
+
+class argmax_equals(max_equals):
+    def fold(self):
+        m = max_equals.fold(self)
+        if m:
+            if not hasattr(m, 'aslist') or len(m.aslist) != 2:
+                raise AggregatorError("argmax expects a pair of values")
+            return m.aslist[1]
+
+class argmin_equals(min_equals):
+    def fold(self):
+        m = min_equals.fold(self)
+        if m:
+            if not hasattr(m, 'aslist') or len(m.aslist) != 2:
+                raise AggregatorError("argmin expects a pair of values")
+            return m.aslist[1]
+
+
 class plus_equals(BAggregator):
     def fold(self):
         s = [k*m for k, m in self.iteritems() if m != 0]
@@ -208,6 +226,8 @@ defs = {
     'set=': set_equals,
     'bag=': bag_equals,
     'mean=': mean_equals,
+    'argmax=': argmax_equals,
+    'argmin=': argmin_equals,
 }
 
 def aggregator(name):
index e07919561093c7d8520df48881de500dbebe5dd7..2cde74710f7cd07a71073fba2ce1b99daa3d70c5 100644 (file)
@@ -207,6 +207,7 @@ class Interpreter(object):
         self.do(self.dynac_code(code), initialize=False)
 
     def new_fn(self, fn, agg):
+
         # check for aggregator conflict.
         if self.agg_name[fn] is None:
             self.agg_name[fn] = agg
@@ -453,7 +454,21 @@ class Interpreter(object):
     def delete_emit(self, item, val, ruleix, variables):
         self.emit(item, val, ruleix, variables, delete=True)
 
-    def emit(self, item, val, ruleix, variables, delete):
+    def emit(self, item, val, ruleix, variables, delete): #, aggregator_to_inherit=None):
+
+#        if item.fn == 'cons/2':
+#            assert isinstance(val, Term) \
+#                and val.fn == 'cons/2' \
+#                and len(val.aslist) == len(item.aslist)
+#            # recurse.
+#            for x, v in zip(item.aslist, val.aslist):
+#                self.emit(x, v, ruleix, variables, delete,
+#                          aggregator_to_inherit=self.rules[ruleix].anf.agg)
+#            return
+#        assert item.fn != 'cons/2' and item.fn != 'nil/0'
+#        if item.aggregator is None:
+#            self.new_fn(item.fn, aggregator_to_inherit)
+
         if delete:
             item.aggregator.dec(val, ruleix, variables)
         else:
@@ -476,8 +491,15 @@ class Interpreter(object):
         """
         assert os.path.exists(filename)
 
+
+
         env = imp.load_source('dynamically_loaded_module', filename)
 
+        if path(filename + '.anf').exists():       # XXX: should have codegen provide this in plan.py
+            with file(filename + '.anf') as f:
+                for anf in read_anf(f.read()):
+                    self.rules[anf.ruleix].anf = anf
+
         for k,v in [('chart', self.chart),
                     ('build', self.build),
                     ('gbc', self.gbc),
@@ -524,11 +546,6 @@ class Interpreter(object):
             for e in emits:
                 self.emit(*e, delete=False)
 
-        if path(filename + '.anf').exists():       # XXX: should have codegen provide this in plan.py
-            with file(filename + '.anf') as f:
-                for anf in read_anf(f.read()):
-                    self.rules[anf.ruleix].anf = anf
-
         return self.go()
 
     def dynac(self, filename):
index a01d60131ce36fbb0b5aeed5c2f2fd65cdb5f1d6..302424b18d5f4704b81315444c3a14531862f1c8 100644 (file)
@@ -1,5 +1,9 @@
 # -*- coding: utf-8 -*-
 
+"""
+TODO: have ANF output which functors are infix, prefix, nullary, etc.
+"""
+
 import re
 from collections import defaultdict
 
index 006c4bf11a3fd07e120cdc5d4467ec070a0a7532..2f47eb2b3e0f29fed35a483b1a8280ad2a69c6d5 100644 (file)
@@ -43,3 +43,6 @@ def _todynalist(x):
     if not x:
         return Nil
     return Cons(x[0], _todynalist(x[1:]))
+
+def get(x, i):
+    return x[i]