Skip to content

Commit

Permalink
Merge pull request #41 from dimagi/type-support
Browse files Browse the repository at this point in the history
Add type support
  • Loading branch information
snopoke committed Jun 30, 2015
2 parents 08fa718 + 02d5953 commit e4052c0
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 62 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
language: python
python:
- "2.6"
- "2.7"
# - "3.3"
install:
Expand Down
3 changes: 2 additions & 1 deletion commcare_export/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def main(argv):
parser.add_argument('--verbose', default=False, action='store_true')
parser.add_argument('--output-format', default='json', choices=['json', 'csv', 'xls', 'xlsx', 'sql', 'markdown'], help='Output format')
parser.add_argument('--output', metavar='PATH', default='reports.zip', help='Path to output; defaults to `reports.zip`.')
parser.add_argument('--strict-types', default=False, action='store_true', help="When saving to a SQL database don't allow changing column types once they are created.")

args = parser.parse_args(argv)

Expand Down Expand Up @@ -137,7 +138,7 @@ def main_with_args(args):
# Writer had bizarre issues so we use a full connection instead of passing in a URL or engine
import sqlalchemy
engine = sqlalchemy.create_engine(args.output)
writer = writers.SqlTableWriter(engine.connect())
writer = writers.SqlTableWriter(engine.connect(), args.strict_types)

if not args.since and not args.start_over and os.path.exists(args.query):
connection = sqlalchemy.create_engine(args.output)
Expand Down
35 changes: 35 additions & 0 deletions commcare_export/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from jsonpath_rw import jsonpath
from jsonpath_rw.parser import parse as parse_jsonpath
from commcare_export.misc import unwrap

from commcare_export.repeatable_iterator import RepeatableIterator

Expand Down Expand Up @@ -217,6 +218,36 @@ def emitted_tables(self):
# Actual concrete environments, basically with built-in functions.
#

@unwrap
def str2bool(val):
if isinstance(val, bool):
return val
return val and str(val).lower() in {'true', 't', '1'}

@unwrap
def str2num(val):
if val is None:
return None

try:
return int(val)
except ValueError:
return float(val)


@unwrap
def str2date(val):
import dateutil.parser as parser
if not val:
return None
return parser.parse(val)


@unwrap
def bool2int(val):
return int(str2bool(val))


class BuiltInEnv(DictEnv):
"""
A built-in environment of operators and functions
Expand All @@ -241,6 +272,10 @@ def __init__(self):
'<=' : operator.__le__,
'len' : len,
'bool': bool,
'str2bool': str2bool,
'bool2int': bool2int,
'str2num': str2num,
'str2date': str2date,
})

def bind(self, name, value): raise CannotBind()
Expand Down
18 changes: 5 additions & 13 deletions commcare_export/minilinq.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from jsonpath_rw import jsonpath
from jsonpath_rw.parser import parse as parse_jsonpath
from commcare_export.misc import unwrap

from commcare_export.repeatable_iterator import RepeatableIterator

Expand Down Expand Up @@ -377,21 +378,12 @@ def __init__(self, table, headings, source):
self.headings = headings
self.source = source

@unwrap
def coerce_cell_blithely(self, cell):
if isinstance(cell, jsonpath.DatumInContext):
cell = cell.value

if isinstance(cell, six.string_types):
if isinstance(cell, list):
return ','.join([self.coerce_cell(item) for item in cell])
else:
return cell
elif isinstance(cell, int):
return str(cell)
elif isinstance(cell, datetime):
return cell
elif cell is None:
return ''

# In all other cases, coerce to a list and join with ',' for now
return ','.join([self.coerce_cell(item) for item in list(cell)])

def coerce_cell(self, cell):
try:
Expand Down
29 changes: 28 additions & 1 deletion commcare_export/misc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes
import functools
import hashlib
import io
import json
from jsonpath_rw import jsonpath
from commcare_export.repeatable_iterator import RepeatableIterator


def digest_file(path):
with io.open(path, 'rb') as filehandle:
Expand All @@ -12,3 +15,27 @@ def digest_file(path):
break
digest.update(chunk)
return digest.hexdigest()


def unwrap(fn):
@functools.wraps(fn)
def _inner(*args):
# handle case when fn is a class method and first arg is 'self'
val = args[1] if len(args) == 2 else args[0]

if isinstance(val, RepeatableIterator):
val = list(val)

if isinstance(val, list):
if len(val) == 1:
val = val[0]
else:
val = map(_inner, val)

if isinstance(val, jsonpath.DatumInContext):
val = val.value

# call fn with 'self' if necessary
return fn(*([val] if len(args) == 1 else [args[0], val]))

return _inner
83 changes: 67 additions & 16 deletions commcare_export/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from six import StringIO, u

from itertools import chain
import datetime

logger = logging.getLogger(__name__)

Expand All @@ -19,6 +20,20 @@ def ensure_text(v):
return v
elif isinstance(v, six.binary_type):
return u(v)
elif isinstance(v, datetime.datetime):
return v.strftime('%Y-%m-%d %H:%M:%S')
elif isinstance(v, datetime.date):
return v.isoformat()
elif v is None:
return ''
else:
return u(str(v))

def to_jvalue(v):
if isinstance(v, (six.text_type,) + six.integer_types):
return v
elif isinstance(v, six.binary_type):
return u(v)
else:
return u(str(v))

Expand Down Expand Up @@ -141,7 +156,7 @@ def write_table(self, table):
# Ensures the table is iterable; probably better to create a custom JSON handler that runs in constant space
self.tables.append(dict(name=table['name'],
headings=list(table['headings']),
rows=[list(row) for row in table['rows']]))
rows=[[to_jvalue(v) for v in row] for row in table['rows']]))

class StreamingMarkdownTableWriter(TableWriter):
"""
Expand All @@ -156,7 +171,7 @@ def write_table(self, table):
self.output_stream.write('|%s|\n' % '|'.join(table['headings']))

for row in table['rows']:
self.output_stream.write('|%s|\n' % '|'.join(row))
self.output_stream.write('|%s|\n' % '|'.join(ensure_text(val) for val in row))

class SqlTableWriter(TableWriter):
"""
Expand All @@ -167,7 +182,7 @@ class SqlTableWriter(TableWriter):
MIN_VARCHAR_LEN=32 # Since SQLite does not actually support ALTER COLUMN type, let's maximize the chance that we do not have to write workarounds by starting medium
MAX_VARCHAR_LEN=255 # Arbitrary point at which we switch to TEXT; for postgres VARCHAR == TEXT anyhow and for Sqlite it doesn't matter either

def __init__(self, connection):
def __init__(self, connection, strict_types=False):
try:
import sqlalchemy
import alembic
Expand All @@ -179,6 +194,7 @@ def __init__(self, connection):
"command: pip install sqlalchemy alembic")

self.base_connection = connection
self.strict_types = strict_types

def __enter__(self):
self.connection = self.base_connection.connect() # "forks" the SqlAlchemy connection
Expand All @@ -199,10 +215,22 @@ def metadata(self):
self._metadata.reflect()
return self._metadata

@property
def is_sqllite(self):
return 'sqlite' in self.connection.engine.driver

def table(self, table_name):
return self.sqlalchemy.Table(table_name, self.metadata, autoload=True, autoload_with=self.connection)

def best_type_for(self, val):
if not self.is_sqllite:
if isinstance(val, bool):
return self.sqlalchemy.Boolean()
elif isinstance(val, datetime.datetime):
return self.sqlalchemy.DateTime()
elif isinstance(val, datetime.date):
return self.sqlalchemy.Date()

if isinstance(val, int):
return self.sqlalchemy.Integer()
elif isinstance(val, six.string_types):
Expand All @@ -225,15 +253,9 @@ def compatible(self, source_type, dest_type):
"""
Checks _coercion_ compatibility.
"""

# FIXME: Add datetime and friends
if isinstance(source_type, self.sqlalchemy.Integer):
# Integers can be cast to varch
return True

if isinstance(source_type, self.sqlalchemy.String):
if not isinstance(dest_type, self.sqlalchemy.String):
False
return False
elif source_type.length is None:
# The length being None means that we are looking at indefinite strings aka TEXT.
# This tool will never create strings with bounds, but if a target DB has one then
Expand All @@ -243,6 +265,18 @@ def compatible(self, source_type, dest_type):
else:
return (dest_type.length >= source_type.length)

compatibility = {
self.sqlalchemy.String: (self.sqlalchemy.Text,),
self.sqlalchemy.Integer: (self.sqlalchemy.String, self.sqlalchemy.Text),
self.sqlalchemy.Boolean: (self.sqlalchemy.String, self.sqlalchemy.Text),
self.sqlalchemy.DateTime: (self.sqlalchemy.String, self.sqlalchemy.Text, self.sqlalchemy.Date),
self.sqlalchemy.Date: (self.sqlalchemy.String, self.sqlalchemy.Text),
}
for _type, types in compatibility.items():
if isinstance(source_type, _type):
return isinstance(dest_type, (_type,) + types)


def least_upper_bound(self, source_type, dest_type):
"""
Returns the _coercion_ least uppper bound.
Expand All @@ -268,25 +302,42 @@ def make_table_compatible(self, table_name, row_dict):
op.create_table(table_name, id_column)
self.metadata.reflect()

def get_cols():
return {c.name: c for c in self.table(table_name).columns}

columns = get_cols()

for column, val in row_dict.items():
ty = self.best_type_for(val)

if not column in [c.name for c in self.table(table_name).columns]:
if not column in columns:
# If we are creating the column, a None crashes things even though it is the "empty" type
# but SQL does not have such a type. So we have to guess a liberal type for future use.
ty = ty or self.sqlalchemy.UnicodeText()
op.add_column(table_name, self.sqlalchemy.Column(column, ty, nullable=True))
self.metadata.clear()
self.metadata.reflect()

columns = get_cols()
else:
columns = dict([(c.name, c) for c in self.table(table_name).columns])
if val is None:
continue

current_ty = columns[column].type

if not self.compatible(ty, current_ty) and not ('sqlite' in self.connection.engine.driver):
op.alter_column(table_name, column, type_ = self.least_upper_bound(current_ty, ty))
if not self.compatible(ty, current_ty):
new_type = self.least_upper_bound(ty, current_ty)
if self.strict_types:
logger.warn('Type mismatch detected for column %s (%s != %s) '
'but strict types in use.', columns[column], current_ty, new_type)
continue
if self.is_sqllite:
logger.warn('Type mismatch detected for column %s (%s != %s) '
'but sqlite does not support changing column types', columns[column], current_ty, new_type)
continue
logger.warn('Altering column %s from %s to %s', columns[column], current_ty, new_type)
op.alter_column(table_name, column, type_=new_type)
self.metadata.clear()
self.metadata.reflect()
columns = get_cols()

def upsert(self, table, row_dict):

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def run_tests(self):
'alembic',
'argparse',
'jsonpath-rw>=1.2.1',
'openpyxl>=2.0.3',
'openpyxl<2.1.0',
'python-dateutil',
'requests',
'ndg-httpsclient',
Expand Down
18 changes: 15 additions & 3 deletions tests/test_minilinq.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ def test_eval_collapsed_list(self):
assert Apply(Reference("*"), Literal(2), Literal(3)).eval(env) == 6
assert Apply(Reference(">"), Literal(56), Literal(23.5)).eval(env) == True
assert Apply(Reference("len"), Literal([1, 2, 3])).eval(env) == 3
assert Apply(Reference("bool"), Literal('a')).eval(env) == True
assert Apply(Reference("bool"), Literal('')).eval(env) == False
assert Apply(Reference("str2bool"), Literal('true')).eval(env) == True
assert Apply(Reference("str2bool"), Literal('t')).eval(env) == True
assert Apply(Reference("str2bool"), Literal('1')).eval(env) == True
assert Apply(Reference("str2bool"), Literal('0')).eval(env) == False
assert Apply(Reference("str2bool"), Literal('false')).eval(env) == False
assert Apply(Reference("str2num"), Literal('10')).eval(env) == 10
assert Apply(Reference("str2num"), Literal('10.56')).eval(env) == 10.56
assert Apply(Reference("str2date"), Literal('2015-01-01')).eval(env) == datetime(2015, 1, 1)
assert Apply(Reference("str2date"), Literal('2015-01-01T18:32:57')).eval(env) == datetime(2015, 1, 1, 18, 32, 57)
assert Apply(Reference("str2date"), Literal('2015-01-01T18:32:57.001200')).eval(env) == datetime(2015, 1, 1, 18, 32, 57, 1200)

def test_map(self):
env = BuiltInEnv() | DictEnv({})
Expand Down Expand Up @@ -90,14 +102,14 @@ def test_flatmap(self):
pass

def test_emit(self):
env = BuiltInEnv() | JsonPathEnv({'foo': {'baz': 3}})
env = BuiltInEnv() | JsonPathEnv({'foo': {'baz': 3, 'bar': True}})
Emit(table='Foo',
headings=[Literal('foo')],
source=List([
List([ Reference('foo.baz') ])
List([ Reference('foo.baz'), Reference('foo.bar') ])
])).eval(env)

assert list(list(env.emitted_tables())[0]['rows']) == [['3']]
assert list(list(env.emitted_tables())[0]['rows']) == [[3, True]]

def test_from_jvalue(self):
assert MiniLinq.from_jvalue({"Ref": "form.log_subreport"}) == Reference("form.log_subreport")
Expand Down
Loading

0 comments on commit e4052c0

Please sign in to comment.