server/schema2sql.py
changeset 10444 fb7c1013189e
parent 10443 2d3834df64ab
child 10481 6ac4b1726e9f
--- a/server/schema2sql.py	Mon May 18 11:36:07 2015 +0200
+++ b/server/schema2sql.py	Sun Mar 22 19:39:29 2015 +0100
@@ -24,7 +24,8 @@
 from six import string_types
 from six.moves import range
 
-from yams.constraints import SizeConstraint, UniqueConstraint, Attribute
+from yams.constraints import (SizeConstraint, UniqueConstraint, Attribute,
+                              NOW, TODAY)
 
 # default are usually not handled at the sql level. If you want them, set
 # SET_DEFAULT to True
@@ -129,7 +130,7 @@
         attr = rschema.type
         rdef = rschema.rdef(eschema.type, aschema.type)
         for constraint in rdef.constraints:
-            cstrname, check = check_constraint(eschema, aschema, attr, constraint, prefix=prefix)
+            cstrname, check = check_constraint(eschema, aschema, attr, constraint, dbhelper, prefix=prefix)
             if cstrname is not None:
                 w(', CONSTRAINT %s CHECK(%s)' % (cstrname, check))
     w(');')
@@ -146,29 +147,31 @@
     w('')
     return '\n'.join(output)
 
-def check_constraint(eschema, aschema, attr, constraint, prefix=''):
+def as_sql(value, dbhelper, prefix):
+    if isinstance(value, Attribute):
+        return prefix + value.attr
+    elif isinstance(value, TODAY):
+        return dbhelper.sql_current_date()
+    elif isinstance(value, NOW):
+        return dbhelper.sql_current_timestamp()
+    else:
+        # XXX more quoting for literals?
+        return value
+
+def check_constraint(eschema, aschema, attr, constraint, dbhelper, prefix=''):
     # XXX should find a better name
     cstrname = 'cstr' + md5(eschema.type + attr + constraint.type() +
                             (constraint.serialize() or '')).hexdigest()
     if constraint.type() == 'BoundaryConstraint':
-        if isinstance(constraint.boundary, Attribute):
-            value = prefix + constraint.boundary.attr
-        else:
-            value = constraint.boundary
+        value = as_sql(constraint.boundary, dbhelper, prefix)
         return cstrname, '%s%s %s %s' % (prefix, attr, constraint.operator, value)
     elif constraint.type() == 'IntervalBoundConstraint':
         condition = []
         if constraint.minvalue is not None:
-            if isinstance(constraint.minvalue, Attribute):
-                value = prefix + constraint.minvalue.attr
-            else:
-                value = constraint.minvalue
+            value = as_sql(constraint.minvalue, dbhelper, prefix)
             condition.append('%s%s >= %s' % (prefix, attr, value))
         if constraint.maxvalue is not None:
-            if isinstance(constraint.maxvalue, Attribute):
-                value = prefix + constraint.maxvalue.attr
-            else:
-                value = constraint.maxvalue
+            value = as_sql(constraint.maxvalue, dbhelper, prefix)
             condition.append('%s%s <= %s' % (prefix, attr, value))
         return cstrname, ' AND '.join(condition)
     elif constraint.type() == 'StaticVocabularyConstraint':