diff --git a/.github/workflows/release_actions.yml b/.github/workflows/release_actions.yml new file mode 100644 index 00000000..d4fee0bd --- /dev/null +++ b/.github/workflows/release_actions.yml @@ -0,0 +1,53 @@ +name: commcare-export release actions +on: + release: + types: [published] + +jobs: + generate_linux_bin: + name: Generate Linux binary as release asset + runs-on: ubuntu-22.04 + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Install pyinstaller + shell: bash + run: python -m pip install pyinstaller + + - name: Generate exe + shell: bash + run: | + pip install commcare-export + pip install -r build_exe/requirements.txt + pyinstaller --dist ./dist/linux commcare-export.spec + + - name: Upload release assets + uses: AButler/upload-release-assets@v3.0 + with: + files: "./dist/linux/*" + repo-token: ${{ secrets.GITHUB_TOKEN }} + + generate_windows_exe: + name: Generate Windows exe as release asset + runs-on: windows-latest + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Install pyinstaller + shell: pwsh + run: python -m pip install pyinstaller + + - name: Generate exe + shell: pwsh + run: | + pip install commcare-export + pip install -r build_exe/requirements.txt + pyinstaller --dist ./dist/windows commcare-export.spec + + - name: Upload release assets + uses: AButler/upload-release-assets@v3.0 + with: + files: "./dist/windows/*" + repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 00000000..31f5aa40 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,72 @@ +name: commcare-export tests +on: + pull_request: + branches: + - master +env: + DB_USER: db_user + DB_PASSWORD: Password123 +jobs: + test: + runs-on: ubuntu-22.04 + services: + mssql: + image: mcr.microsoft.com/mssql/server:2017-latest + env: + SA_PASSWORD: ${{ env.DB_PASSWORD }} + ACCEPT_EULA: 'Y' + ports: + - 1433:1433 + postgres: + image: postgres + env: + POSTGRES_PASSWORD: ${{ env.DB_PASSWORD }} + POSTGRES_USER: ${{ env.DB_USER }} + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + strategy: + matrix: + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 50 + - run: git fetch --tags origin # So we can use git describe. actions/checkout@v3 does not pull tags. 
+ + # MySQL set up + - run: sudo service mysql start # Ubuntu already includes mysql no need to use service + - run: mysql -uroot -proot -e "CREATE USER '${{ env.DB_USER }}'@'%';" + - run: mysql -uroot -proot -e "GRANT ALL PRIVILEGES ON *.* TO '${{ env.DB_USER }}'@'%';" + + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - run: sudo apt-get install pandoc + - run: pip install --upgrade pip + - run: pip install setuptools + - run: python setup.py sdist + - run: pip install dist/* + - run: pip install pymysql psycopg2 pyodbc + - run: pip install coverage coveralls + - run: pip install mypy + - run: pip install pytest + - run: pip install -e ".[test]" + - run: coverage run setup.py test + env: + POSTGRES_URL: postgresql://${{ env.DB_USER }}:${{ env.DB_PASSWORD }}@localhost/ + MYSQL_URL: mysql+pymysql://${{ env.DB_USER }}:@localhost/ + MSSQL_URL: mssql+pyodbc://sa:${{ env.DB_PASSWORD }}@localhost/ + HQ_USERNAME: ${{ secrets.HQ_USERNAME }} + HQ_API_KEY: ${{ secrets.HQ_API_KEY }} + - run: mypy --install-types --non-interactive @mypy_typed_modules.txt + - run: coverage lcov -o coverage/lcov.info + - name: Coveralls + uses: coverallsapp/github-action@v2 + with: + github-token: + ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index fb47a4b9..69607db9 100644 --- a/.gitignore +++ b/.gitignore @@ -49,3 +49,4 @@ nosetests.xml # Excel ~*.xlsx +commcare_export.log \ No newline at end of file diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 00000000..d09f55c8 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,5 @@ +# https://github.com/timothycrosley/isort/wiki/isort-Settings +[settings] +multi_line_output=3 +include_trailing_comma=true + diff --git a/.style.yapf b/.style.yapf new file mode 100644 index 00000000..d7bee203 --- /dev/null +++ b/.style.yapf @@ -0,0 +1,11 @@ +[style] +based_on_style = yapf +column_limit = 79 +indent_width = 4 +coalesce_brackets = true +dedent_closing_brackets = true +spaces_before_comment = 2 +split_before_arithmetic_operator = true +split_before_bitwise_operator = true +split_before_logical_operator = true +split_all_top_level_comma_separated_values = true diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 1c692d69..00000000 --- a/.travis.yml +++ /dev/null @@ -1,42 +0,0 @@ -language: python -sudo: required -dist: "xenial" -python: - - "2.7" - - "3.6" - - "3.7" -addons: - apt: - packages: - - unixodbc-dev -env: - global: - - MSSQL_SA_PASSWORD=Password@123 - # HQ_USERNAME and HQ_API_KEY - - secure: etv02uWtyy5P4DfyuHjFm5RDFc6WBHLsnIMC75VjDk8kxDgwV/lDbPYMh/hzfPHyskgA1EQbc8IfHlbZWFVV8jOTy+wvrVir/mw95AEyNyAL/TTSWvYfTvdCsxOSbY6vcGlJNfy6rc+y0h6QyuIknY0OhU8sTaRcQnvbFPnOz28= - - secure: aLj1bKtUF2CnAwG+yjiAjo39cKi9WHaonIwqsuhOx4McsD/xSz4QHv/6/XhXZ5KxKyxw1+PBl/mWo6gyrT5iHDRBPk5iJXqZAgQFS2ukZSv/tUBGL7bWzoO9YfoLuWllA33DCr3PiXAhkH53dTcor16UN9wXeCprBBSGjhpAxRQ= -before_install: - - docker pull microsoft/mssql-server-linux:2017-latest - - docker run -e "ACCEPT_EULA=Y" -e "MSSQL_SA_PASSWORD=$MSSQL_SA_PASSWORD" -p 1433:1433 --name mssql1 -d microsoft/mssql-server-linux:2017-latest - - curl https://packages.microsoft.com/keys/microsoft.asc | sudo apt-key add - - - echo "deb [arch=amd64] https://packages.microsoft.com/ubuntu/14.04/prod trusty main" | sudo tee /etc/apt/sources.list.d/mssql-release.list - - sudo apt-get update -qq -install: - - sudo apt-get install pandoc - - python setup.py sdist - - pip install dist/* - - pip install pymysql psycopg2 pyodbc - - pip install coverage coveralls - - sudo 
ACCEPT_EULA=Y apt-get install msodbcsql17 -before_script: - - mysql -u root -e "GRANT ALL PRIVILEGES ON *.* TO 'travis'@'%';"; - - docker ps -a - - odbcinst -q -d - - .travis/wait.sh -script: coverage run setup.py test -after_success: - - coveralls -services: - - postgresql - - mysql - - docker diff --git a/.travis/wait.sh b/.travis/wait.sh deleted file mode 100755 index 67497e45..00000000 --- a/.travis/wait.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -echo "Waiting MSSQL docker to launch on 1433..." - -while ! nc -z localhost 1433; do - sleep 0.1 -done - -echo "MSSQL launched" diff --git a/README.md b/README.md index f3ef96be..b9ace252 100644 --- a/README.md +++ b/README.md @@ -3,11 +3,11 @@ CommCare Export https://github.com/dimagi/commcare-export -[![Build Status](https://travis-ci.org/dimagi/commcare-export.png)](https://travis-ci.org/dimagi/commcare-export) +[![Build Status](https://app.travis-ci.com/dimagi/commcare-export.svg?branch=master)](https://app.travis-ci.com/dimagi/commcare-export) [![Test coverage](https://coveralls.io/repos/dimagi/commcare-export/badge.png?branch=master)](https://coveralls.io/r/dimagi/commcare-export) [![PyPI version](https://badge.fury.io/py/commcare-export.svg)](https://badge.fury.io/py/commcare-export) -A command-line tool (and Python library) to generate customized exports from the [CommCareHQ](https://www.commcarehq.org) [REST API](https://wiki.commcarehq.org/display/commcarepublic/Data+APIs). +A command-line tool (and Python library) to generate customized exports from the [CommCare HQ](https://www.commcarehq.org) [REST API](https://wiki.commcarehq.org/display/commcarepublic/Data+APIs). * [User documentation](https://wiki.commcarehq.org/display/commcarepublic/CommCare+Data+Export+Tool) * [Changelog](https://github.com/dimagi/commcare-export/releases) @@ -15,28 +15,84 @@ A command-line tool (and Python library) to generate customized exports from the Installation & Quick Start -------------------------- -0a\. Install Python and `pip`. This tool is [tested with Python 2.7, 3.6 and 3.7](https://travis-ci.org/dimagi/commcare-export). +Following commands are to be run on a terminal or a command line. -0b\. Sign up for [CommCareHQ](https://www.commcarehq.org/) if you have not already. +Once on a terminal window or command line, for simplicity, run commands from the home directory. -1\. Install CommCare Export via `pip` +### Python +Check which Python version is installed. + +This tool is tested with Python versions from 3.8 to 3.12. + +```shell +$ python3 --version +``` +If Python is installed, its version will be shown. + +If Python isn't installed, [download and install](https://www.python.org/downloads/) +a version of Python from 3.8 to 3.12. + +## Virtualenv (Optional) + +It is recommended to set up a virtual environment for CommCare Export +to avoid conflicts with other Python applications. + +More about virtualenvs on https://docs.python.org/3/tutorial/venv.html + +Setup a virtual environment using: + +```shell +$ python3 -m venv venv ``` + +Activate virtual environment by running: + +```shell +$ source venv/bin/activate +``` + +**Note**: virtualenv needs to be activated each time you start a new terminal session or command line prompt. 
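+
+If you are not sure whether the virtual environment is currently active, a
+quick check (a sketch; the exact path will depend on where you created the
+"venv" directory) is to see which Python interpreter your shell is using:
+
+```shell
+$ which python
+# example output -- your path will differ
+/home/yourname/venv/bin/python
+```
+
+If the path does not point inside the `venv` directory, re-run the
+activation command above.
+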
+ +For convenience, to avoid doing that, you can create an alias to activate virtual environments in +"venv" directory by adding the following to your +`.bashrc` or `.zshrc` file: + +```shell +$ alias venv='if [[ -d venv ]] ; then source venv/bin/activate ; fi' +``` + +Then you can activate virtual environments with simply typing +```shell +$ venv +``` + +## Install CommCare Export + +Install CommCare Export via `pip` + +```shell $ pip install commcare-export ``` -2\. Create a project space and application. +## CommCare HQ -3\. Visit the Release Manager, make a build, click the star to release it. +1. Sign up for [CommCare HQ](https://www.commcarehq.org/) if you have not already. -4\. Use Web Apps and fill out some forms. +2. Create a project space and application. -5\. Modify one of example queries in the `examples/` directory, modifying the "Filter Value" column +3. Visit the Release Manager, make a build, click the star to release it. + +4. Use Web Apps and fill out some forms. + +5. Modify one of example queries in the `examples/` directory, modifying the "Filter Value" column to match your form XMLNS / case type. See [this page](https://confluence.dimagi.com/display/commcarepublic/Finding+a+Form%27s+XMLNS) to determine the XMLNS for your form. -``` +Now you can run the following examples: + +```shell $ commcare-export \ --query examples/demo-registration.xlsx \ --project YOUR_PROJECT \ @@ -61,7 +117,7 @@ $ commcare-export \ You'll see the tables printed out. Change to `--output-format sql --output URL_TO_YOUR_DB --since DATE` to sync all forms submitted since that date. -All examples are present in Excel and also equivalent JSON, however it is recommended +Example query files are provided in both Excel and JSON format. It is recommended to use the Excel format as the JSON format may change upon future library releases. Command-line Usage @@ -69,7 +125,7 @@ Command-line Usage The basic usage of the command-line tool is with a saved Excel or JSON query (see how to write these, below) -``` +```shell $ commcare-export --commcare-hq \ --username \ --project \ @@ -85,7 +141,7 @@ $ commcare-export --commcare-hq \ See `commcare-export --help` for the full list of options. -There are example query files for the CommCare Demo App (available on the CommCareHq Exchange) in the `examples/` +There are example query files for the CommCare Demo App (available on the CommCare HQ Exchange) in the `examples/` directory. `--output` @@ -107,7 +163,7 @@ mssql+pyodbc://scott:tiger@localhost/mydatabases?driver=ODBC+Driver+17+for+SQL+S Excel Queries ------------- -An excel query is any `.xlsx` workbook. Each sheet in the workbook represents one table you wish +An Excel query is any `.xlsx` workbook. Each sheet in the workbook represents one table you wish to create. There are two grouping of columns to configure the table: - **Data Source**: Set this to `form` to export form data, or `case` for case data. @@ -120,7 +176,7 @@ JSON Queries ------------ JSON queries are a described in the table below. You build a JSON object that represents the query you have in mind. -A good way to get started is to work from the examples, or you could make an excel query and run the tool +A good way to get started is to work from the examples, or you could make an Excel query and run the tool with `--dump-query` to see the resulting JSON query. @@ -131,24 +187,24 @@ The --users and --locations options export data from a CommCare project that can be joined with form and case data. 
The --with-organization option does all of that and adds a field to Excel query specifications to be joined on. -Specifiying the --users option or --with-organization option will export an +Specifying the --users option or --with-organization option will export an additional table named 'commcare_users' containing the following columns: -Column | Type | Note ------- | ---- | ---- -id | Text | Primary key -default_phone_number | Text | -email | Text | -first_name | Text | -groups | Text | -last_name | Text | -phone_numbers | Text | -resource_uri | Text | -commcare_location_id | Text | Foreign key into the commcare_locations table -commcare_location_ids | Text | -commcare_primary_case_sharing_id | Text | -commcare_project | Text | -username | Text | +| Column | Type | Note | +|----------------------------------|------|-------------------------------------| +| id | Text | Primary key | +| default_phone_number | Text | | +| email | Text | | +| first_name | Text | | +| groups | Text | | +| last_name | Text | | +| phone_numbers | Text | | +| resource_uri | Text | | +| commcare_location_id | Text | Foreign key to `commcare_locations` | +| commcare_location_ids | Text | | +| commcare_primary_case_sharing_id | Text | | +| commcare_project | Text | | +| username | Text | | The data in the 'commcare_users' table comes from the [List Mobile Workers API endpoint](https://confluence.dimagi.com/display/commcarepublic/List+Mobile+Workers). @@ -156,28 +212,28 @@ API endpoint](https://confluence.dimagi.com/display/commcarepublic/List+Mobile+W Specifying the --locations option or --with-organization options will export an additional table named 'commcare_locations' containing the following columns: -Column | Type | Note ------- | ---- | ---- -id | Text | -created_at | Date | -domain | Text | -external_id | Text | -last_modified | Date | -latitude | Text | -location_data | Text | -location_id | Text | Primary key -location_type | Text | -longitude | Text | -name | Text | -parent | Text | Resource URI of parent location -resource_uri | Text | -site_code | Text | -location_type_administrative | Text | -location_type_code | Text | -location_type_name | Text | -location_type_parent | Text | -*location level code* | Text | Column name depends on project's organization -*location level code* | Text | Column name depends on project's organization +| Column | Type | Note | +|------------------------------|------|-----------------------------------------------| +| id | Text | | +| created_at | Date | | +| domain | Text | | +| external_id | Text | | +| last_modified | Date | | +| latitude | Text | | +| location_data | Text | | +| location_id | Text | Primary key | +| location_type | Text | | +| longitude | Text | | +| name | Text | | +| parent | Text | Resource URI of parent location | +| resource_uri | Text | | +| site_code | Text | | +| location_type_administrative | Text | | +| location_type_code | Text | | +| location_type_name | Text | | +| location_type_parent | Text | | +| *location level code* | Text | Column name depends on project's organization | +| *location level code* | Text | Column name depends on project's organization | The data in the 'commcare_locations' table comes from the Location API endpoint along with some additional columns from the Location Type API @@ -189,17 +245,17 @@ location at that level of your organization. Consider the example organization from the [CommCare help page](https://confluence.dimagi.com/display/commcarepublic/Setting+up+Organization+Levels+and+Structure). 
A piece of the 'commcare_locations' table could look like this: -location_id | location_type_name | chw | supervisor | clinic | district ------------ | ------------------ | ------ | ---------- | ------ | -------- -939fa8 | District | NULL | NULL | NULL | 939fa8 -c4cbef | Clinic | NULL | NULL | c4cbef | 939fa8 -a9ca40 | Supervisor | NULL | a9ca40 | c4cbef | 939fa8 -4545b9 | CHW | 4545b9 | a9ca40 | c4cbef | 939fa8 +| location_id | location_type_name | chw | supervisor | clinic | district | +|-------------|--------------------|--------|------------|--------|----------| +| 939fa8 | District | NULL | NULL | NULL | 939fa8 | +| c4cbef | Clinic | NULL | NULL | c4cbef | 939fa8 | +| a9ca40 | Supervisor | NULL | a9ca40 | c4cbef | 939fa8 | +| 4545b9 | CHW | 4545b9 | a9ca40 | c4cbef | 939fa8 | In order to join form or case data to 'commcare_users' and 'commcare_locations' the exported forms and cases need to contain a field identifying which user submitted them. The --with-organization option automatically adds a field -called 'commcare_userid' to each query in an Excel specifiction for this +called 'commcare_userid' to each query in an Excel specification for this purpose. Using that field, you can use a SQL query with a join to report data about any level of you organization. For example, to count the number of forms submitted by all workers in each clinic: @@ -228,35 +284,93 @@ you will change the columns of the 'commcare_locations' table and it is very likely you will want to drop the table before exporting with the new organization. +Scheduling the DET +------------------ +Scheduling the DET to run at regular intervals is a useful tactic to keep your +database up to date with CommCare HQ. + +A common approach to scheduling DET runs is making use of the operating systems' scheduling +libraries to invoke a script to execute the `commcare-export` command. Sample scripts can be +found in the `examples/` directory for both Windows and Linux. + +### Windows +On Windows systems you can make use of the [task scheduler](https://sqlbackupandftp.com/blog/how-to-schedule-a-script-via-windows-task-scheduler/) +to run scheduled scripts for you. + +The `examples/` directory contains a sample script file, `scheduled_run_windows.bat`, which can be used by the +task scheduler to invoke the `commcare-export` command. + +To set up the scheduled task you can follow the steps below. +1. Copy the file `scheduled_run_windows.bat` to any desired location on your system (e.g. `Documents`) +2. Edit the copied `.bat` file and populate your own details +3. Follow the steps outlined [here](https://sqlbackupandftp.com/blog/how-to-schedule-a-script-via-windows-task-scheduler/), +using the .bat file when prompted for the `Program/script`. + + +### Linux +On a Linux system you can make use of the [crontab](https://www.techtarget.com/searchdatacenter/definition/crontab) +command to create scheduled actions (cron jobs) in the system. + +The `examples/` directory contains a sample script file, `scheduled_run_linux.sh`, which can be used by the cron job. +To set up the cron job you can follow the steps below. +1. Copy the example file to the home directory +> cp ./examples/scheduled_run_linux.sh ~/scheduled_run_linux.sh +2. Edit the file to populate your own details +> nano ~/scheduled_run_linux.sh +3. Create a cron job by appending to the crontab file +> crontab -e + +Make an entry below any existing cron jobs. 
The example below executes the script file at the top of +every 12th hour of every day +> 0 12 * * * bash ~/scheduled_run_linux.sh + +You can consult the [crontab.guru](https://crontab.guru/) tool which is very useful to generate and interpret +any custom cron schedules. + Python Library Usage -------------------- As a library, the various `commcare_export` modules make it easy to - - Interact with the CommCareHQ REST API + - Interact with the CommCare HQ REST API - Execute "Minilinq" queries against the API (a very simple query language, described below) - Load and save JSON representations of Minilinq queries - Compile Excel configurations to Minilinq queries -To directly access the CommCareHq REST API: +To directly access the CommCare HQ REST API: ```python ->>> import getpass ->>> from commcare_export.commcare_hq_client import CommCareHqClient ->>> api_client = CommCareHqClient('http://commcarehq.org', 'your_project', 'your_username', getpass.getpass()) ->>> forms = api_client.iterate('form', {'app_id': "whatever"}) ->>> [ (form['received_on'], form['form.gender']) for form in forms ] +from commcare_export.checkpoint import CheckpointManagerWithDetails +from commcare_export.commcare_hq_client import CommCareHqClient, AUTH_MODE_APIKEY +from commcare_export.commcare_minilinq import get_paginator, PaginationMode + +username = 'some@username.com' +domain = 'your-awesome-domain' +hq_host = 'https://commcarehq.org' +API_KEY= 'your_secret_api_key' + +api_client = CommCareHqClient(hq_host, domain, username, API_KEY, AUTH_MODE_APIKEY) +case_paginator=get_paginator(resource='case', pagination_mode=PaginationMode.date_modified) +case_paginator.init() +checkpoint_manager=CheckpointManagerWithDetails(None, None, PaginationMode.date_modified) + +cases = api_client.iterate('case', case_paginator, checkpoint_manager=checkpoint_manager) + +for case in cases: + print(case['case_id']) + ``` To issue a `minilinq` query against it, and then print out that query in a JSON serialization: ```python -import getpass import json +import sys from commcare_export.minilinq import * from commcare_export.commcare_hq_client import CommCareHqClient from commcare_export.commcare_minilinq import CommCareHqEnv -from commcare_export.env import BuiltInEnv +from commcare_export.env import BuiltInEnv, JsonPathEnv +from commcare_export.writers import StreamingMarkdownTableWriter api_client = CommCareHqClient( url="http://www.commcarehq.org", @@ -287,20 +401,19 @@ query = Emit( source ) -print json.dumps(query.to_jvalue(), indent=2) +print(json.dumps(query.to_jvalue(), indent=2)) results = query.eval(BuiltInEnv() | CommCareHqEnv(api_client) | JsonPathEnv()) if len(list(env.emitted_tables())) > 0: - # with writers.Excel2007TableWriter("excel-output.xlsx") as writer: - with writers.StreamingMarkdownTableWriter(sys.stdout) as writer: + with StreamingMarkdownTableWriter(sys.stdout) as writer: for table in env.emitted_tables(): writer.write_table(table) ``` Which will output JSON equivalent to this: -```javascript +```json { "Emit": { "headings": [ @@ -323,7 +436,7 @@ Which will output JSON equivalent to this: } ] }, - "name": None, + "name": null, "source": { "Apply": { "args": [ @@ -381,19 +494,23 @@ List of builtin functions: | Function | Description | Example Usage | |--------------------------------|--------------------------------------------------------------------------------|----------------------------------| | `+, -, *, //, /, >, <, >=, <=` | Standard Math | | -| len | Length | | -| bool | Bool | | -| str2bool | Convert 
string to boolean. True values are 'true', 't', '1' (case insensitive) | | -| str2date | Convert string to date | | -| bool2int | Convert boolean to integer (0, 1) | | -| str2num | Parse string as a number | | -| substr | Returns substring indexed by [first arg, second arg), zero-indexed. | substr(2, 5) of 'abcdef' = 'cde' | -| selected-at | Returns the Nth word in a string. N is zero-indexed. | selected-at(3) - return 4th word | -| selected | Returns True if the given word is in the value. | selected(fever) | -| count-selected | Count the number of words | | -| json2str | Convert a JSON object to a string | -| template | Render a string template (not robust) | template({} on {}, state, date) | -| attachment_url | Convert an attachment name into it's download URL | | +| len | Length | | +| bool | Bool | | +| str2bool | Convert string to boolean. True values are 'true', 't', '1' (case insensitive) | | +| str2date | Convert string to date | | +| bool2int | Convert boolean to integer (0, 1) | | +| str2num | Parse string as a number | | +| format-uuid | Parse a hex UUID, and format it into hyphen-separated groups | | +| substr | Returns substring indexed by [first arg, second arg), zero-indexed. | substr(2, 5) of 'abcdef' = 'cde' | +| selected-at | Returns the Nth word in a string. N is zero-indexed. | selected-at(3) - return 4th word | +| selected | Returns True if the given word is in the value. | selected(fever) | +| count-selected | Count the number of words | | +| json2str | Convert a JSON object to a string | | +| template | Render a string template (not robust) | template({} on {}, state, date) | +| attachment_url | Convert an attachment name into it's download URL | | +| form_url | Output the URL to the form view on CommCare HQ | | +| case_url | Output the URL to the case view on CommCare HQ | | +| unique | Ouptut only unique values in a list | | Output Formats -------------- @@ -422,30 +539,41 @@ Required dependencies will be automatically installed via pip. But since you may not care about all export formats, the various dependencies there are optional. Here is how you might install them: -``` +```shell # To export "xlsx" -$ pip install openpyxl +$ pip install "commcare-export[xlsx]" # To export "xls" -$ pip install xlwt +$ pip install "commcare-export[xls]" + +# To sync with a Postgres database +$ pip install "commcare-export[postgres]" + +# To sync with a mysql database +$ pip install "commcare-export[mysql]" -# To sync with a SQL database -$ pip install SQLAlchemy alembic psycopg2 pymysql pyodbc +# To sync with a database which uses odbc (e.g. mssql) +$ pip install "commcare-export[odbc]" + +# To sync with another SQL database supported by SQLAlchemy +$ pip install "commcare-export[base_sql]" +# Then install the Python package for your database ``` Contributing ------------ -0\. Sign up for github, if you have not already, at https://github.com. +0\. Sign up for GitHub, if you have not already, at https://github.com. 1\. Fork the repository at https://github.com/dimagi/commcare-export. -2\. Clone your fork, install into a `virtualenv`, and start a feature branch +2\. Clone your fork, install into a virtualenv, and start a feature branch -``` -$ mkvirtualenv commcare-export +```shell $ git clone git@github.com:dimagi/commcare-export.git $ cd commcare-export +$ python3 -m venv venv +$ source venv/bin/activate $ pip install -e ".[test]" $ git checkout -b my-super-duper-feature ``` @@ -454,7 +582,7 @@ $ git checkout -b my-super-duper-feature 4\. Make sure the tests pass. 
The best way to test for all versions is to sign up for https://travis-ci.org and turn on automatic continuous testing for your fork. -``` +```shell $ py.test =============== test session starts =============== platform darwin -- Python 2.7.3 -- pytest-2.3.4 @@ -469,38 +597,44 @@ tests/test_writers.py ... ============ 17 passed in 2.09 seconds ============ ``` -5\. Push the feature branch up +5\. Type hints are used in the `env` and `minilinq` modules. Check that any changes in those modules adhere to those types: +```shell +$ mypy --install-types @mypy_typed_modules.txt ``` + +6\. Push the feature branch up + +```shell $ git push -u origin my-super-duper-feature ``` -6\. Visit https://github.com/dimagi/commcare-export and submit a pull request. +7\. Visit https://github.com/dimagi/commcare-export and submit a pull request. -7\. Accept our gratitude for contributing: Thanks! +8\. Accept our gratitude for contributing: Thanks! Release process --------------- 1\. Create a tag for the release -``` +```shell $ git tag -a "X.YY.0" -m "Release X.YY.0" $ git push --tags ``` 2\. Create the source distribution -``` +```shell $ python setup.py sdist ``` Ensure that the archive (`dist/commcare-export-X.YY.0.tar.gz`) has the correct version number (matching the tag name). 3\. Upload to pypi -``` +```shell $ pip install twine -$ twine upload dist/commcare-export-X.YY.0.tar.gz +$ twine upload -u dimagi dist/commcare-export-X.YY.0.tar.gz ``` 4\. Verify upload @@ -511,34 +645,40 @@ https://pypi.python.org/pypi/commcare-export https://github.com/dimagi/commcare-export/releases +Once the release is published a GitHub workflow is kicked off that compiles executables of the DET compatible with +Linux and Windows machines, adding it to the release as assets. + +[For Linux-based users] If you decide to download and use the executable file, please make sure the file has the executable permission enabled, +after which it can be invoked like any other executable though the command line. + Testing and Test Databases -------------------------- The following command will run the entire test suite (requires DB environment variables to be set as per below): -``` +```shell $ py.test ``` To run an individual test class or method you can run, e.g.: -``` +```shell $ py.test -k "TestExcelQuery" $ py.test -k "test_get_queries_from_excel" ``` To exclude the database tests you can run: -``` +```shell $ py.test -m "not dbtest" ``` When running database tests, supported databases are PostgreSQL, MySQL, MSSQL. 
To run tests against selected databases can be done using test marks as follows: -``` -py.test -m [postgres,mysql,mssql] +```shell +$ py.test -m [postgres,mysql,mssql] ``` Database URLs can be overridden via environment variables: @@ -550,14 +690,17 @@ MSSQL_URL=mssql+pyodbc://user:password@host/ Postgresql ========== -``` +```shell $ docker pull postgres:9.6 -$ docker run --name ccexport-postgres -p 5432:5432 -d postgres:9.6 +$ docker run --name ccexport-postgres -p 5432:5432 -e POSTGRES_PASSWORD=postgres -d postgres:9.6 +$ export POSTGRES_URL=postgresql://postgres:postgres@localhost/ ``` +[Docker postgres image docs](https://hub.docker.com/_/postgres/) + MySQL ===== -``` +```shell $ docker pull mysql $ docker run --name ccexport-mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=pw -e MYSQL_USER=travis -e MYSQL_PASSWORD='' -d mysql @@ -569,8 +712,8 @@ mysql> GRANT ALL PRIVILEGES ON *.* TO 'travis'@'%'; MSSQL ===== -``` -$ docker pull microsoft/mssql-server-linux:2017-latest +```shell +$ docker pull mcr.microsoft.com/mssql/server:2017-latest $ docker run -e "ACCEPT_EULA=Y" -e "MSSQL_SA_PASSWORD=Password@123" -p 1433:1433 --name mssql1 -d microsoft/mssql-server-linux:2017-latest # install driver @@ -584,8 +727,8 @@ $ odbcinst -q -d MSSQL for Mac OS ========== -``` -$ docker pull microsoft/mssql-server-linux:2017-latest +```shell +$ docker pull mcr.microsoft.com/mssql/server:2017-latest $ docker run -e "ACCEPT_EULA=Y" -e "MSSQL_SA_PASSWORD=Password@123" -p 1433:1433 --name mssql1 -d microsoft/mssql-server-linux:2017-latest # Install driver @@ -599,7 +742,7 @@ Setup=/usr/local/lib/libtdsodbc.so UsageCount=1 # Create a soft link from /etc/odbcinst.ini to actual file -sudo ln -s /usr/local/etc/odbcinst.ini /etc/odbcinst.ini +$ sudo ln -s /usr/local/etc/odbcinst.ini /etc/odbcinst.ini ``` @@ -611,9 +754,9 @@ access to the corpora domain. These need to be set as environment variables as follows: -``` -export HQ_USERNAME= -export HQ_API_KEY= +```shell +$ export HQ_USERNAME= +$ export HQ_API_KEY= ``` For Travis builds these are included as encrypted vars in the travis diff --git a/build_exe/README.md b/build_exe/README.md new file mode 100644 index 00000000..28997f34 --- /dev/null +++ b/build_exe/README.md @@ -0,0 +1,18 @@ +# Compiling DET to running executable +This folder contains relevant files needed for compiling the DET into an executable file. +The executable is generated on after every release of the DET and the resultant files are uploaded +to the release as assets. + +## Testing locally +In the event that you want to test the exe compilation locally you can follow the steps below: + +Install `pyinstaller`: +> python -m pip install pyinstaller + +Now create the executable (assuming you're running this on a Linux machine): +> pyinstaller --dist ./dist/linux commcare-export.spec + +The resultant executable file can be located under `./dist/linux/`. + +The argument, `commcare-export.spec`, is a simple configuration file used by +pyinstaller which you ideally shouldn't have to ever change. 
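+
+If you want to verify the result (assuming a Linux build), you can invoke the
+generated file directly, for example:
+
+> ./dist/linux/commcare-export --version
+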
\ No newline at end of file diff --git a/build_exe/requirements.txt b/build_exe/requirements.txt new file mode 100644 index 00000000..5231b26e --- /dev/null +++ b/build_exe/requirements.txt @@ -0,0 +1,8 @@ +# This file is only used by pyinstaller to create the executable DET instance +chardet +psycopg2-binary +pymysql +pyodbc +urllib3==1.26.7 +xlwt +openpyxl diff --git a/build_exe/runtime_hook.py b/build_exe/runtime_hook.py new file mode 100644 index 00000000..d226247a --- /dev/null +++ b/build_exe/runtime_hook.py @@ -0,0 +1,4 @@ +import os + +# This env variable is used to alter bundled behaviour +os.environ['DET_EXECUTABLE'] = '1' diff --git a/commcare-export.spec b/commcare-export.spec new file mode 100644 index 00000000..428f7f40 --- /dev/null +++ b/commcare-export.spec @@ -0,0 +1,40 @@ +# -*- mode: python ; coding: utf-8 -*- + + +a = Analysis( + ['commcare_export/cli.py'], + pathex=[], + binaries=[], + datas=[ + ('./commcare_export', './commcare_export'), + ('./migrations', './migrations'), + ], + hiddenimports=[ + 'sqlalchemy.sql.default_comparator', + ], + hookspath=[], + runtime_hooks=['build_exe/runtime_hook.py'], + excludes=[], +) +pyz = PYZ(a.pure) + +exe = EXE( + pyz, + a.scripts, + a.binaries, + a.datas, + [], + name='commcare-export', + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=True, + upx_exclude=[], + runtime_tmpdir=None, + console=True, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, +) diff --git a/commcare_export/__init__.py b/commcare_export/__init__.py index 58f3ace6..cd015d74 100644 --- a/commcare_export/__init__.py +++ b/commcare_export/__init__.py @@ -1 +1,39 @@ +import logging +import os +import re from .version import __version__ + +repo_root = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir)) + + +class Logger: + def __init__(self, logger, level): + self.logger = logger + self.level = level + self.linebuf = '' + + def write(self, buf): + for line in buf.rstrip().splitlines(): + self.logger.log(self.level, line.rstrip()) + + +def logger_name_from_filepath(filepath): + relative_path = os.path.relpath(filepath, start=repo_root) + cleaned_path = relative_path.replace('/', '.') + return re.sub(r'\.py$', '', cleaned_path) + + +def get_error_logger(): + return Logger(logging.getLogger(), logging.ERROR) + + +def get_logger(filepath=None): + if filepath: + logger = logging.getLogger( + logger_name_from_filepath(filepath) + ) + else: + logger = logging.getLogger() + + logger.setLevel(logging.DEBUG) + return logger diff --git a/commcare_export/builtin_queries.py b/commcare_export/builtin_queries.py index da2f2acd..b1167ed8 100644 --- a/commcare_export/builtin_queries.py +++ b/commcare_export/builtin_queries.py @@ -11,6 +11,7 @@ class Column: + def __init__(self, name, source, map_function=None, *extra_args): self.name = Literal(name) self.source = source @@ -22,22 +23,31 @@ def mapped_source_field(self): if not self.map_function: return Reference(self.source) else: - return Apply(Reference(self.map_function), Reference(self.source), - *self.extra_args) + return Apply( + Reference(self.map_function), + Reference(self.source), + *self.extra_args + ) def compile_query(columns, data_source, table_name): - source = Apply(Reference('api_data'), Literal(data_source), - Reference('checkpoint_manager')) - part = excel_query.SheetParts(table_name, [c.name for c in columns], source, - List([c.mapped_source_field for c in columns]), - None) + source = Apply( + 
Reference('api_data'), + Literal(data_source), + Reference('checkpoint_manager') + ) + part = excel_query.SheetParts( + table_name, [c.name for c in columns], + source, + List([c.mapped_source_field for c in columns]), + None + ) return excel_query.compile_queries([part], None, False)[0] -# A MiniLinq query for internal CommCare user table. -# It reads every field produced by the /user/ API endpoint and -# writes the data to a table named "commcare_users" in a database. +# A MiniLinq query for internal CommCare user table. It reads every +# field produced by the /user/ API endpoint and writes the data to a +# table named "commcare_users" in a database. user_columns = [ Column('id', 'id'), @@ -50,26 +60,29 @@ def compile_query(columns, data_source, table_name): Column('resource_uri', 'resource_uri'), Column('commcare_location_id', 'user_data.commcare_location_id'), Column('commcare_location_ids', 'user_data.commcare_location_ids'), - Column('commcare_primary_case_sharing_id', - 'user_data.commcare_primary_case_sharing_id'), + Column( + 'commcare_primary_case_sharing_id', + 'user_data.commcare_primary_case_sharing_id' + ), Column('commcare_project', 'user_data.commcare_project'), Column('username', 'username') ] users_query = compile_query(user_columns, 'user', USERS_TABLE_NAME) +# A MiniLinq query for internal CommCare location table. It reads every +# field produced by the /location/ API endpoint and appends fields to +# hold parent locations using location_type information before writing +# the data to a table named "commcare_locations" in a database. -# A MiniLinq query for internal CommCare location table. -# It reads every field produced by the /location/ API endpoint and -# appends fields to hold parent locations using location_type information -# before writing the data to a table named "commcare_locations" in a database. def get_locations_query(lp): location_types = lp.location_types - # For test stability and clarity, we order location types from deepest - # to shallowest. + # For test stability and clarity, we order location types from + # deepest to shallowest. depth = {} + def set_depth(lt): if lt not in depth: parent = location_types[lt]['parent'] @@ -82,12 +95,14 @@ def set_depth(lt): for lt in location_types: set_depth(lt) - ordered_location_types = sorted(location_types.values(), - key=lambda lt: -depth[lt['resource_uri']]) + ordered_location_types = sorted( + location_types.values(), key=lambda lt: -depth[lt['resource_uri']] + ) location_codes = [lt['code'] for lt in ordered_location_types] # The input names are codes produced by Django's slugify utility - # method. Replace hyphens with underscores to be easier to use in SQL. + # method. Replace hyphens with underscores to be easier to use in + # SQL. 
def sql_column_name(code): return re.sub('-', '_', code) @@ -106,24 +121,47 @@ def sql_column_name(code): Column('parent', 'parent'), Column('resource_uri', 'resource_uri'), Column('site_code', 'site_code'), - Column('location_type_administrative', 'location_type', - 'get_location_info', Literal('administrative')), - Column('location_type_code', 'location_type', - 'get_location_info', Literal('code')), - Column('location_type_name', 'location_type', - 'get_location_info', Literal('name')), - Column('location_type_parent', 'location_type', - 'get_location_info', Literal('parent')), - ] + [Column(sql_column_name(code), - 'resource_uri', 'get_location_ancestor', - Literal(code)) for code in location_codes] - return compile_query(location_columns, 'location', - LOCATIONS_TABLE_NAME) + Column( + 'location_type_administrative', + 'location_type', + 'get_location_info', + Literal('administrative') + ), + Column( + 'location_type_code', + 'location_type', + 'get_location_info', + Literal('code') + ), + Column( + 'location_type_name', + 'location_type', + 'get_location_info', + Literal('name') + ), + Column( + 'location_type_parent', + 'location_type', + 'get_location_info', + Literal('parent') + ), + ] + [ + Column( + sql_column_name(code), + 'resource_uri', + 'get_location_ancestor', + Literal(code) + ) for code in location_codes + ] + return compile_query(location_columns, 'location', LOCATIONS_TABLE_NAME) + # Require specified columns in emitted tables. class ColumnEnforcer(): - columns_to_require = {'form': Column('commcare_userid', '$.metadata.userID'), - 'case': Column('commcare_userid', '$.user_id')} + columns_to_require = { + 'form': Column('commcare_userid', '$.metadata.userID'), + 'case': Column('commcare_userid', '$.user_id') + } def __init__(self): self._emitted_tables = set([]) @@ -133,4 +171,3 @@ def column_to_require(self, data_source): return ColumnEnforcer.columns_to_require[data_source] else: return None - diff --git a/commcare_export/checkpoint.py b/commcare_export/checkpoint.py index 6d56afe8..f28e6266 100644 --- a/commcare_export/checkpoint.py +++ b/commcare_export/checkpoint.py @@ -1,25 +1,20 @@ -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes - import datetime -import logging -import uuid - import os +import uuid from contextlib import contextmanager from operator import attrgetter import dateutil.parser -import six -from sqlalchemy import Column, String, Boolean, func, and_ +from sqlalchemy import Boolean, Column, String, and_, func from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker +from commcare_export.commcare_minilinq import PaginationMode from commcare_export.exceptions import DataExportException from commcare_export.writers import SqlMixin +from commcare_export import get_logger, repo_root -logger = logging.getLogger(__name__) -repo_root = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir)) - +logger = get_logger(__file__) Base = declarative_base() @@ -36,6 +31,21 @@ class Checkpoint(Base): since_param = Column(String) time_of_run = Column(String) final = Column(Boolean) + data_source = Column(String) + last_doc_id = Column(String) + pagination_mode = Column(String) + cursor = Column(String) + + def get_pagination_mode(self): + """ + Get Enum from value stored in the checkpoint. Null or empty + value defaults to 'date_modified' mode to support legacy + checkpoints. 
+ """ + if not self.pagination_mode: + return PaginationMode.date_modified + + return PaginationMode[self.pagination_mode] def __repr__(self): return ( @@ -49,13 +59,19 @@ def __repr__(self): "commcare={r.commcare}, " "since_param={r.since_param}, " "time_of_run={r.time_of_run}, " - "final={r.final})>".format(r=self) - ) + "final={r.final}), " + "data_source={r.data_source}, " + "last_doc_id={r.last_doc_id}, " + "pagination_mode={r.pagination_mode}," + "cursor={r.cursor}>" + ).format(r=self) @contextmanager def session_scope(Session): - """Provide a transactional scope around a series of operations.""" + """ + Provide a transactional scope around a series of operations. + """ session = Session() try: yield session @@ -71,8 +87,22 @@ class CheckpointManager(SqlMixin): table_name = 'commcare_export_runs' migrations_repository = os.path.join(repo_root, 'migrations') - def __init__(self, db_url, query, query_md5, project, commcare, key=None, table_names=None, poolclass=None, engine=None): - super(CheckpointManager, self).__init__(db_url, poolclass=poolclass, engine=engine) + def __init__( + self, + db_url, + query, + query_md5, + project, + commcare, + key=None, + table_names=None, + poolclass=None, + engine=None, + data_source=None + ): + super(CheckpointManager, self).__init__( + db_url, poolclass=poolclass, engine=engine + ) self.query = query self.query_md5 = query_md5 self.project = project @@ -80,30 +110,66 @@ def __init__(self, db_url, query, query_md5, project, commcare, key=None, table_ self.key = key self.Session = sessionmaker(self.engine, expire_on_commit=False) self.table_names = table_names + self.data_source = data_source - def for_tables(self, table_names): + def for_dataset(self, data_source, table_names): return CheckpointManager( - self.db_url, self.query, self.query_md5, self.project, self.commcare, self.key, - engine=self.engine, table_names=table_names + self.db_url, + self.query, + self.query_md5, + self.project, + self.commcare, + self.key, + engine=self.engine, + table_names=table_names, + data_source=data_source ) - def set_checkpoint(self, checkpoint_time, is_final=False): - self._set_checkpoint(checkpoint_time, is_final) + def set_checkpoint( + self, + checkpoint_time, + pagination_mode, + is_final=False, + doc_id=None, + cursor=None, + ): + self._set_checkpoint( + checkpoint_time, + pagination_mode, + is_final, + doc_id=doc_id, + cursor=cursor, + ) if is_final: self._cleanup() - def _set_checkpoint(self, checkpoint_time, final, time_of_run=None): + def _set_checkpoint( + self, + checkpoint_time, + pagination_mode, + final, + time_of_run=None, + doc_id=None, + cursor=None, + ): logger.info( - 'Setting %s checkpoint for tables %s: %s', + 'Setting %s checkpoint: data_source: %s, tables: %s, ' + 'pagination_mode: %s, checkpoint: %s:%s', + # 'final' if final else 'batch', + self.data_source, ', '.join(self.table_names), - checkpoint_time + pagination_mode.name, + checkpoint_time, + doc_id, ) if not checkpoint_time: - raise DataExportException('Tried to set an empty checkpoint. This is not allowed.') + raise DataExportException( + 'Tried to set an empty checkpoint. This is not allowed.' 
+ ) self._validate_tables() - if isinstance(checkpoint_time, six.text_type): + if isinstance(checkpoint_time, str): since_param = checkpoint_time else: since_param = checkpoint_time.isoformat() @@ -120,8 +186,13 @@ def _set_checkpoint(self, checkpoint_time, final, time_of_run=None): project=self.project, commcare=self.commcare, since_param=since_param, - time_of_run=time_of_run or datetime.datetime.utcnow().isoformat(), - final=final + time_of_run=time_of_run + or datetime.datetime.utcnow().isoformat(), + final=final, + data_source=self.data_source, + last_doc_id=doc_id, + pagination_mode=pagination_mode.name, + cursor=cursor, ) session.add(checkpoint) created.append(checkpoint) @@ -129,7 +200,9 @@ def _set_checkpoint(self, checkpoint_time, final, time_of_run=None): def create_checkpoint_table(self, revision='head'): from alembic import command, config - cfg = config.Config(os.path.join(self.migrations_repository, 'alembic.ini')) + cfg = config.Config( + os.path.join(self.migrations_repository, 'alembic.ini') + ) cfg.set_main_option('script_location', self.migrations_repository) with self.engine.begin() as connection: cfg.attributes['connection'] = connection @@ -139,62 +212,96 @@ def _cleanup(self): self._validate_tables() with session_scope(self.Session) as session: session.query(Checkpoint).filter_by( - final=False, query_file_md5=self.query_md5, - project=self.project, commcare=self.commcare - ).filter(Checkpoint.table_name.in_(self.table_names)).delete(synchronize_session='fetch') + final=False, + query_file_md5=self.query_md5, + project=self.project, + commcare=self.commcare + ).filter(Checkpoint.table_name.in_(self.table_names) + ).delete(synchronize_session='fetch') def get_time_of_last_checkpoint(self, log_warnings=True): - """Return the earliest time from the list of checkpoints that for the current - query file / key.""" + """ + Return the earliest time from the list of checkpoints that for + the current query file / key. + """ run = self.get_last_checkpoint() if run and log_warnings: self.log_warnings(run) return run.since_param if run else None def get_last_checkpoint(self): - """Return a single checkpoint such that it has the earliest `since_param` of all - checkpoints for the active tables.""" + """ + Return a single checkpoint such that it has the earliest + `since_param` of all checkpoints for the active tables. 
+ """ self._validate_tables() table_runs = [] with session_scope(self.Session) as session: for table in self.table_names: if self.key: table_run = self._get_last_checkpoint( - session, table_name=table, - key=self.key, project=self.project, commcare=self.commcare + session, + table_name=table, + key=self.key, + project=self.project, + commcare=self.commcare ) else: table_run = self._get_last_checkpoint( - session, table_name=table, - query_file_md5=self.query_md5, project=self.project, commcare=self.commcare, key=self.key + session, + table_name=table, + query_file_md5=self.query_md5, + project=self.project, + commcare=self.commcare, + key=self.key ) if table_run: - table_runs.append(table_run) + table_runs.append(table_run) if not table_runs: table_runs = self.get_legacy_checkpoints() if table_runs: - sorted_runs = list(sorted(table_runs, key=attrgetter('time_of_run'))) + sorted_runs = list( + sorted(table_runs, key=attrgetter('time_of_run')) + ) return sorted_runs[0] def get_legacy_checkpoints(self): with session_scope(self.Session) as session: # check without table_name table_run = self._get_last_checkpoint( - session, query_file_md5=self.query_md5, table_name=None, - project=self.project, commcare=self.commcare, key=self.key + session, + query_file_md5=self.query_md5, + table_name=None, + project=self.project, + commcare=self.commcare, + key=self.key ) if table_run: - return self._set_checkpoint(table_run.since_param, table_run.final, table_run.time_of_run) + return self._set_checkpoint( + table_run.since_param, + PaginationMode.date_modified, + table_run.final, + table_run.time_of_run + ) # Check for run without the args table_run = self._get_last_checkpoint( - session, query_file_md5=self.query_md5, key=self.key, - project=None, commcare=None, table_name=None + session, + query_file_md5=self.query_md5, + key=self.key, + project=None, + commcare=None, + table_name=None ) if table_run: - return self._set_checkpoint(table_run.since_param, table_run.final, table_run.time_of_run) + return self._set_checkpoint( + table_run.since_param, + PaginationMode.date_modified, + table_run.final, + table_run.time_of_run + ) def _get_last_checkpoint(self, session, **kwarg_filters): query = session.query(Checkpoint) @@ -202,8 +309,7 @@ def _get_last_checkpoint(self, session, **kwarg_filters): query = query.filter_by(**kwarg_filters) return query.order_by(Checkpoint.time_of_run.desc()).first() - def log_warnings(self, run): - # type: (Checkpoint) -> None + def log_warnings(self, run: Checkpoint) -> None: md5_mismatch = run.query_file_md5 != self.query_md5 name_mismatch = run.query_file_name != self.query if md5_mismatch or name_mismatch: @@ -211,12 +317,16 @@ def log_warnings(self, run): "Query differs from most recent checkpoint:\n" "From checkpoint: name=%s, md5=%s\n" "From command line args: name=%s, md5=%s\n", - run.query_file_name, run.query_file_md5, - self.query, self.query_md5 + # + run.query_file_name, + run.query_file_md5, + self.query, + self.query_md5 ) def list_checkpoints(self, limit=20): - """List all checkpoints filtered by: + """ + List all checkpoints filtered by: * file name * project * commcare @@ -240,24 +350,36 @@ def _filter_query(self, query): return query def get_latest_checkpoints(self): - """Returns the latest checkpoint for each table filtered by the fields set in the manager: + """ + Returns the latest checkpoint for each table filtered by the + fields set in the manager: * query_md5 * project * commcare * key """ with session_scope(self.Session) as session: - cols = 
[Checkpoint.project, Checkpoint.commcare, Checkpoint.query_file_md5, Checkpoint.table_name] + cols = [ + Checkpoint.project, + Checkpoint.commcare, + Checkpoint.query_file_md5, + Checkpoint.table_name + ] inner_query = self._filter_query( session.query( - *(cols + [func.max(Checkpoint.time_of_run).label('max_time_of_run')]) - ) - .filter(Checkpoint.query_file_md5 == self.query_md5) - .filter(Checkpoint.table_name.isnot(None)) + *( + cols + [ + func.max(Checkpoint.time_of_run + ).label('max_time_of_run') + ] + ) + ).filter(Checkpoint.query_file_md5 == self.query_md5 + ).filter(Checkpoint.table_name.isnot(None)) ).group_by(*cols).subquery() query = session.query(Checkpoint).join( - inner_query, and_( + inner_query, + and_( Checkpoint.project == inner_query.c.project, Checkpoint.commcare == inner_query.c.commcare, Checkpoint.query_file_md5 == inner_query.c.query_file_md5, @@ -270,15 +392,23 @@ def get_latest_checkpoints(self): # Keeping for future reference # # window_func = func.row_number().over( - # partition_by=Checkpoint.table_name, order_by=Checkpoint.time_of_run.desc() + # partition_by=Checkpoint.table_name, + # order_by=Checkpoint.time_of_run.desc() # ).label("row_number") - # inner_query = self._filter_query(session.query(Checkpoint, window_func)) - # inner_query = inner_query.filter(Checkpoint.query_file_md5 == self.query_md5) - # inner_query = inner_query.filter(Checkpoint.table_name.isnot(None)).subquery() + # inner_query = self._filter_query( + # session.query(Checkpoint, window_func) + # ) + # inner_query = inner_query.filter( + # Checkpoint.query_file_md5 == self.query_md5 + # ) + # inner_query = inner_query.filter( + # Checkpoint.table_name.isnot(None) + # ).subquery() # # query = session.query(Checkpoint).select_entity_from(inner_query)\ # .filter(inner_query.c.row_number == 1)\ # .order_by(Checkpoint.table_name.asc()) + # return list(query) def update_checkpoint(self, run): @@ -290,18 +420,28 @@ def _validate_tables(self): raise Exception("Not tables set in checkpoint manager") -class CheckpointManagerWithSince(object): - def __init__(self, manager, since): +class CheckpointManagerWithDetails(object): + + def __init__(self, manager, since_param, pagination_mode): self.manager = manager - self.since_param = since + self.since_param = since_param + self.pagination_mode = pagination_mode - def set_checkpoint(self, checkpoint_time, is_final=False): + def set_checkpoint(self, checkpoint_time, is_final=False, doc_id=None, cursor=None): if self.manager: - self.manager.set_checkpoint(checkpoint_time, is_final) + self.manager.set_checkpoint( + checkpoint_time, self.pagination_mode, is_final, doc_id=doc_id, cursor=cursor + ) class CheckpointManagerProvider(object): - def __init__(self, base_checkpoint_manager=None, since=None, start_over=None): + + def __init__( + self, + base_checkpoint_manager=None, + since=None, + start_over=None, + ): self.start_over = start_over self.since = since self.base_checkpoint_manager = base_checkpoint_manager @@ -314,24 +454,70 @@ def get_since(self, checkpoint_manager): return self.since if checkpoint_manager: + if checkpoint_manager.data_source == 'ucr': + last_checkpoint = checkpoint_manager.get_last_checkpoint() + return last_checkpoint.cursor if last_checkpoint else None + since = checkpoint_manager.get_time_of_last_checkpoint() return dateutil.parser.parse(since) if since else None - def get_checkpoint_manager(self, table_names): - """This get's called before each table is exported and set in the `env`. 
It is then - passed to the API client and used to set the checkpoints. + def get_pagination_mode(self, data_source, checkpoint_manager=None): + """ + Always use the default pagination mode unless we are continuing + from a previous checkpoint in which case use the same pagination + mode as before. + """ + if self.start_over or self.since or not checkpoint_manager: + return self.get_paginator_for_datasource(data_source) + + last_checkpoint = checkpoint_manager.get_last_checkpoint() + if not last_checkpoint: + return self.get_paginator_for_datasource(data_source) + + return last_checkpoint.get_pagination_mode() - :param table_names: List of table names being exported to. This is a list since - multiple tables can be processed by a since API query. + @staticmethod + def get_paginator_for_datasource(datasource): + if datasource == 'ucr': + return PaginationMode.cursor + return PaginationMode.date_indexed + + def get_checkpoint_manager(self, data_source, table_names): + """ + This get's called before each table is exported and set in the + `env`. It is then passed to the API client and used to set the + checkpoints. + + :param data_source: Data source for this checkout e.g. 'form' + :param table_names: List of table names being exported to. This + is a list since multiple tables can be processed by a + 'since' API query. """ manager = None if self.base_checkpoint_manager: - manager = self.base_checkpoint_manager.for_tables(table_names) + manager = self.base_checkpoint_manager.for_dataset( + data_source, table_names + ) since = self.get_since(manager) - + pagination_mode = self.get_pagination_mode(data_source, checkpoint_manager=manager) logger.info( - "Creating checkpoint manager for tables: %s with 'since' parameter: %s", - ', '.join(table_names), since + "Creating checkpoint manager for tables: %s, since: %s, " + "pagination_mode: %s", + # + ', '.join(table_names), + since, + pagination_mode.name, ) - return CheckpointManagerWithSince(manager, since) + if pagination_mode not in PaginationMode.supported_modes(): + logger.warning( + "\n====================================\n" + "This export is using a deprecated pagination mode which will " + "be removed in\n" + "future versions. To switch to the new mode you must re-sync " + "your data using\n" + "`--start-over`. 
For more details see: %s" + "\n====================================\n", # + "https://github.com/dimagi/commcare-export/releases/tag/1.5.0" + ) + return CheckpointManagerWithDetails(manager, since, pagination_mode) diff --git a/commcare_export/cli.py b/commcare_export/cli.py index c65c8cf0..e3a7b1af 100644 --- a/commcare_export/cli.py +++ b/commcare_export/cli.py @@ -1,36 +1,37 @@ -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes - import argparse import getpass import io import json -import logging import os.path import sys - +import logging import dateutil.parser import requests import sqlalchemy -from six.moves import input -from commcare_export import excel_query -from commcare_export import writers +from commcare_export import builtin_queries, excel_query, writers from commcare_export.checkpoint import CheckpointManagerProvider -from commcare_export.misc import default_to_json -from commcare_export.utils import get_checkpoint_manager -from commcare_export.commcare_hq_client import CommCareHqClient, LATEST_KNOWN_VERSION, ResourceRepeatException +from commcare_export.commcare_hq_client import ( + LATEST_KNOWN_VERSION, + CommCareHqClient, + ResourceRepeatException, +) from commcare_export.commcare_minilinq import CommCareHqEnv -from commcare_export.env import BuiltInEnv, JsonPathEnv, EmitterEnv -from commcare_export.exceptions import LongFieldsException, DataExportException, MissingQueryFileException -from commcare_export.minilinq import MiniLinq, List +from commcare_export.env import BuiltInEnv, EmitterEnv, JsonPathEnv +from commcare_export.exceptions import ( + DataExportException, + MissingQueryFileException, +) +from commcare_export.location_info_provider import LocationInfoProvider +from commcare_export.minilinq import List, MiniLinq +from commcare_export.misc import default_to_json from commcare_export.repeatable_iterator import RepeatableIterator +from commcare_export.utils import get_checkpoint_manager from commcare_export.version import __version__ -from commcare_export import builtin_queries -from commcare_export.location_info_provider import LocationInfoProvider +from commcare_export import get_logger, get_error_logger EXIT_STATUS_ERROR = 1 - -logger = logging.getLogger(__name__) +logger = get_logger(__file__) commcare_hq_aliases = { 'local': 'http://localhost:8000', @@ -39,6 +40,7 @@ class Argument(object): + def __init__(self, name, *args, **kwargs): self.name = name.replace('-', '_') self._args = ['--{}'.format(name)] + list(args) @@ -54,68 +56,163 @@ def add_to_parser(self, parser, **additional_kwargs): CLI_ARGS = [ - Argument('version', default=False, action='store_true', - help='Print the current version of the commcare-export tool.'), - Argument('query', required=False, help='JSON or Excel query file'), - Argument('dump-query', default=False, action='store_true'), - Argument('commcare-hq', default='prod', - help='Base url for the CommCare HQ instance e.g. https://www.commcarehq.org'), - Argument('api-version', default=LATEST_KNOWN_VERSION), - Argument('project'), - Argument('username'), - Argument('password', help='Enter password, or if using apikey auth-mode, enter the api key.'), - Argument('auth-mode', default='password', choices=['password', 'apikey'], - help='Use "digest" auth, or "apikey" auth (for two factor enabled domains).'), - Argument('since', help='Export all data after this date. Format YYYY-MM-DD or YYYY-MM-DDTHH:mm:SS'), - Argument('until', help='Export all data up until this date. 
Format YYYY-MM-DD or YYYY-MM-DDTHH:mm:SS'), - Argument('start-over', default=False, action='store_true', - help='When saving to a SQL database; the default is to pick up since the last success. This disables that.'), - Argument('profile'), - Argument('verbose', default=False, action='store_true'), - Argument('output-format', default='json', choices=['json', 'csv', 'xls', 'xlsx', 'sql', 'markdown'], - help='Output format'), - Argument('output', metavar='PATH', default='reports.zip', help='Path to output; defaults to `reports.zip`.'), - Argument('strict-types', default=False, action='store_true', - help="When saving to a SQL database don't allow changing column types once they are created."), - Argument('missing-value', default=None, help="Value to use when a field is missing from the form / case."), - Argument('batch-size', default=200, help="Number of records to process per batch."), - Argument('checkpoint-key', help="Use this key for all checkpoints instead of the query file MD5 hash " - "in order to prevent table rebuilds after a query file has been edited."), - Argument('users', default=False, action='store_true', - help="Export a table containing data about this project's " - "mobile workers"), - Argument('locations', default=False, action='store_true', - help="Export a table containing data about this project's " - "locations"), - Argument('with-organization', default=False, action='store_true', - help="Export tables containing mobile worker data and " - "location data and add a commcare_userid field to any " - "exported form or case"), - ] + Argument( + 'version', + default=False, + action='store_true', + help='Print the current version of the commcare-export tool.' + ), + Argument('query', required=False, help='JSON or Excel query file'), + Argument('dump-query', default=False, action='store_true'), + Argument( + 'commcare-hq', + default='prod', + help='Base url for the CommCare HQ instance e.g. ' + 'https://www.commcarehq.org' + ), + Argument('api-version', default=LATEST_KNOWN_VERSION), + Argument('project'), + Argument('username'), + Argument( + 'password', + help='Enter password, or if using apikey auth-mode, enter the api key.' + ), + Argument( + 'auth-mode', + default='password', + choices=['password', 'apikey'], + help='Use "digest" auth, or "apikey" auth (for two factor enabled ' + 'domains).' + ), + Argument( + 'since', + help='Export all data after this date. Format YYYY-MM-DD or ' + 'YYYY-MM-DDTHH:mm:SS' + ), + Argument( + 'until', + help='Export all data up until this date. Format YYYY-MM-DD or ' + 'YYYY-MM-DDTHH:mm:SS' + ), + Argument( + 'start-over', + default=False, + action='store_true', + help='When saving to a SQL database; the default is to pick up ' + 'since the last success. This disables that.' + ), + Argument('profile'), + Argument('verbose', default=False, action='store_true'), + Argument( + 'output-format', + default='json', + choices=['json', 'csv', 'xls', 'xlsx', 'sql', 'markdown'], + help='Output format' + ), + Argument( + 'output', + metavar='PATH', + default='reports.zip', + help='Path to output; defaults to `reports.zip`.' + ), + Argument( + 'strict-types', + default=False, + action='store_true', + help="When saving to a SQL database don't allow changing column types " + "once they are created." + ), + Argument( + 'missing-value', + default=None, + help="Value to use when a field is missing from the form / case." + ), + Argument( + 'batch-size', + default=200, + help="Number of records to process per batch." 
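[Editor's note, not part of the patch] The CLI arguments above can also be exercised programmatically. A minimal sketch of driving the exporter through cli.main(); the project, credentials, query workbook and database URL below are placeholder values, and auth-mode 'apikey' means --password carries the API key as described in the help text.

    from commcare_export.cli import main

    # Hypothetical invocation: export an Excel query to a local PostgreSQL
    # database, processing 500 records per API batch.
    main([
        '--project', 'my-project',                 # placeholder project space
        '--username', 'user@example.com',          # placeholder username
        '--auth-mode', 'apikey',
        '--password', 'my-api-key',                # API key when auth-mode is apikey
        '--query', 'my_query.xlsx',                # placeholder query workbook
        '--output-format', 'sql',
        '--output', 'postgresql://db_user:password@localhost/export_db',
        '--batch-size', '500',
    ])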
+ ), + Argument( + 'checkpoint-key', + help="Use this key for all checkpoints instead of the query file MD5 " + "hash in order to prevent table rebuilds after a query file has " + "been edited." + ), + Argument( + 'users', + default=False, + action='store_true', + help="Export a table containing data about this project's mobile " + "workers" + ), + Argument( + 'locations', + default=False, + action='store_true', + help="Export a table containing data about this project's locations" + ), + Argument( + 'with-organization', + default=False, + action='store_true', + help="Export tables containing mobile worker data and location data " + "and add a commcare_userid field to any exported form or case" + ), + Argument( + 'export-root-if-no-subdocument', + default=False, + action='store_true', + help="Use this when you are exporting a nested document e.g. " + "form.form..case, messaging-event.messages.[*] And you want to " + "have a record exported even if the nested document does not " + "exist or is empty.", + ), + Argument( + 'no-logfile', + default=False, + help="Specify in order to prevent information being logged to the log file and" + " show all output in the console.", + action='store_true', + ), +] def main(argv): - parser = argparse.ArgumentParser('commcare-export', 'Output a customized export of CommCareHQ data.') + parser = argparse.ArgumentParser( + 'commcare-export', 'Output a customized export of CommCareHQ data.' + ) for arg in CLI_ARGS: arg.add_to_parser(parser) - try: - args = parser.parse_args(argv) - except UnicodeDecodeError: - for arg in argv: - try: - arg.encode('utf-8') - except UnicodeDecodeError: - print(u"ERROR: Argument '%s' contains unicode characters. " - u"Only ASCII characters are supported.\n" % unicode(arg, 'utf-8'), file=sys.stderr) - sys.exit(1) + args = parser.parse_args(argv) + + if args.output_format and args.output: + errors = [] + errors.extend(validate_output_filename(args.output_format, args.output)) + if errors: + raise Exception(f"Could not proceed. 
Following issues were found: {', '.join(errors)}.") + + if not args.no_logfile: + exe_dir = os.path.dirname(sys.executable) + log_file = os.path.join(exe_dir, "commcare_export.log") + print(f"Printing logs to {log_file}") + logging.basicConfig( + filename=log_file, + format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s', + filemode='w', + ) + sys.stderr = get_error_logger() if args.verbose: - logging.basicConfig(level=logging.DEBUG, - format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s') + logging.basicConfig( + level=logging.DEBUG, + format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s' + ) else: - logging.basicConfig(level=logging.WARN, - format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s') + logging.basicConfig( + level=logging.WARN, + format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s' + ) logging.getLogger('alembic').setLevel(logging.WARN) logging.getLogger('backoff').setLevel(logging.FATAL) @@ -123,11 +220,18 @@ def main(argv): if args.version: print('commcare-export version {}'.format(__version__)) - exit(0) + sys.exit(0) if not args.project: - print('commcare-export: error: argument --project is required', file=sys.stderr) - exit(1) + error_msg = "commcare-export: error: argument --project is required" + # output to log file through sys.stderr + print( + error_msg, + file=sys.stderr + ) + # Output to console for debugging + print(error_msg) + sys.exit(1) if args.profile: # hotshot is gone in Python 3 @@ -137,8 +241,17 @@ def main(argv): profile.start() try: - exit(main_with_args(args)) + print("Running export...") + try: + exit_code = main_with_args(args) + if exit_code > 0: + print("Error occurred! See log file for error.") + sys.exit(exit_code) + except Exception: + print("Error occurred! See log file for error.") + raise finally: + print("Export finished!") if args.profile: profile.close() stats = hotshot.stats.load(args.profile) @@ -147,6 +260,23 @@ def main(argv): stats.print_stats(100) +def validate_output_filename(output_format, output_filename): + """ + Validate file extensions for csv, xls and xlsx output formats. + Ensure extension unless using sql output_format. + """ + errors = [] + if output_format == 'csv' and not output_filename.endswith('.zip'): + errors.append("For output format as csv, output file name should have extension zip") + elif output_format == 'xls' and not output_filename.endswith('.xls'): + errors.append("For output format as xls, output file name should have extension xls") + elif output_format == 'xlsx' and not output_filename.endswith('.xlsx'): + errors.append("For output format as xlsx, output file name should have extension xlsx") + elif output_format != 'sql' and "." 
not in output_filename: + errors.append("Missing extension in output file name") + return errors + + def _get_query(args, writer, column_enforcer=None): return _get_query_from_file( args.query, @@ -154,23 +284,38 @@ def _get_query(args, writer, column_enforcer=None): writer.supports_multi_table_write, writer.max_column_length, writer.required_columns, - column_enforcer + column_enforcer, + args.export_root_if_no_subdocument ) -def _get_query_from_file(query_arg, missing_value, combine_emits, - max_column_length, required_columns, column_enforcer): + +def _get_query_from_file( + query_arg, + missing_value, + combine_emits, + max_column_length, + required_columns, + column_enforcer, + value_or_root +): if os.path.exists(query_arg): if os.path.splitext(query_arg)[1] in ['.xls', '.xlsx']: import openpyxl workbook = openpyxl.load_workbook(query_arg) return excel_query.get_queries_from_excel( - workbook, missing_value, combine_emits, - max_column_length, required_columns, column_enforcer + workbook, + missing_value, + combine_emits, + max_column_length, + required_columns, + column_enforcer, + value_or_root ) else: with io.open(query_arg, encoding='utf-8') as fh: return MiniLinq.from_jvalue(json.loads(fh.read())) + def get_queries(args, writer, lp, column_enforcer=None): query_list = [] if args.query is not None: @@ -198,17 +343,30 @@ def _get_writer(output_format, output, strict_types): return writers.Excel2003TableWriter(output) elif output_format == 'csv': if not output.endswith(".zip"): - print("WARNING: csv output is a zip file, but " - "will be written to %s" % output) - print("Consider appending .zip to the file name to avoid confusion.") + print( + "WARNING: csv output is a zip file, but " + "will be written to %s" % output + ) + print( + "Consider appending .zip to the file name to avoid confusion." + ) return writers.CsvTableWriter(output) elif output_format == 'json': return writers.JValueTableWriter() elif output_format == 'markdown': return writers.StreamingMarkdownTableWriter(sys.stdout) elif output_format == 'sql': - # Output should be a connection URL - # Writer had bizarre issues so we use a full connection instead of passing in a URL or engine + # Output should be a connection URL. Writer had bizarre issues + # so we use a full connection instead of passing in a URL or + # engine. + if output.startswith('mysql'): + charset_split = output.split('charset=') + if len(charset_split) > 1 and charset_split[1] != 'utf8mb4': + raise Exception( + f"The charset '{charset_split[1]}' might cause problems with the export. " + f"It is recommended that you use 'utf8mb4' instead." 
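[Editor's note, not part of the patch] A brief illustration of the MySQL charset check above, using placeholder connection URLs: a URL that names a charset other than utf8mb4 is rejected, while one that requests utf8mb4, or omits the charset parameter entirely, is accepted.

    def charset_is_rejected(url):
        # Mirrors the check performed in _get_writer above.
        parts = url.split('charset=')
        return len(parts) > 1 and parts[1] != 'utf8mb4'

    ok_url = 'mysql+pymysql://db_user:password@localhost/export_db?charset=utf8mb4'
    no_charset_url = 'mysql+pymysql://db_user:password@localhost/export_db'
    bad_url = 'mysql+pymysql://db_user:password@localhost/export_db?charset=latin1'

    assert not charset_is_rejected(ok_url)          # accepted
    assert not charset_is_rejected(no_charset_url)  # accepted: check skipped
    assert charset_is_rejected(bad_url)             # raises in _get_writer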
+ ) + return writers.SqlTableWriter(output, strict_types) else: raise Exception("Unknown output format: {}".format(output_format)) @@ -232,11 +390,17 @@ def _get_api_client(args, commcarehq_base_url): def _get_checkpoint_manager(args): - if not args.users and not args.locations and not os.path.exists(args.query): - logger.warning("Checkpointing disabled for non builtin, " - "non file-based query") + if not args.users and not args.locations and not os.path.exists( + args.query + ): + logger.warning( + "Checkpointing disabled for non builtin, " + "non file-based query" + ) elif args.since or args.until: - logger.warning("Checkpointing disabled when using '--since' or '--until'") + logger.warning( + "Checkpointing disabled when using '--since' or '--until'" + ) else: checkpoint_manager = get_checkpoint_manager(args) checkpoint_manager.create_checkpoint_table() @@ -259,8 +423,11 @@ def evaluate_query(env, query): force_lazy_result(lazy_result) return 0 except requests.exceptions.RequestException as e: - if e.response.status_code == 401: - print("\nAuthentication failed. Please check your credentials.", file=sys.stderr) + if e.response and e.response.status_code == 401: + print( + "\nAuthentication failed. Please check your credentials.", + file=sys.stderr + ) return EXIT_STATUS_ERROR else: raise @@ -269,8 +436,11 @@ def evaluate_query(env, query): print(e.message) print('Try increasing --batch-size to overcome the error') return EXIT_STATUS_ERROR - except (sqlalchemy.exc.DataError, sqlalchemy.exc.InternalError, - sqlalchemy.exc.ProgrammingError) as e: + except ( + sqlalchemy.exc.DataError, + sqlalchemy.exc.InternalError, + sqlalchemy.exc.ProgrammingError + ) as e: print('Stopping because of database error:\n', e) return EXIT_STATUS_ERROR except KeyboardInterrupt: @@ -283,15 +453,29 @@ def main_with_args(args): writer = _get_writer(args.output_format, args.output, args.strict_types) if args.query is None and args.users is False and args.locations is False: - print('At least one the following arguments is required: ' - '--query, --users, --locations', file=sys.stderr) + print( + 'At least one the following arguments is required: ' + '--query, --users, --locations', + file=sys.stderr + ) return EXIT_STATUS_ERROR + if not args.username: + logger.warn("Username not provided") + args.username = input('Please provide a username: ') + + if not args.password: + logger.warn("Password not provided") + # Windows getpass does not accept unicode + args.password = getpass.getpass() + column_enforcer = None if args.with_organization: column_enforcer = builtin_queries.ColumnEnforcer() - commcarehq_base_url = commcare_hq_aliases.get(args.commcare_hq, args.commcare_hq) + commcarehq_base_url = commcare_hq_aliases.get( + args.commcare_hq, args.commcare_hq + ) api_client = _get_api_client(args, commcarehq_base_url) lp = LocationInfoProvider(api_client, page_size=args.batch_size) try: @@ -308,17 +492,13 @@ def main_with_args(args): if writer.support_checkpoints: checkpoint_manager = _get_checkpoint_manager(args) - if not args.username: - args.username = input('Please provide a username: ') - - if not args.password: - # Windows getpass does not accept unicode - args.password = getpass.getpass() - since, until = get_date_params(args) if args.start_over: if checkpoint_manager: - logger.warning('Ignoring all checkpoints and re-fetching all data from CommCare.') + logger.warning( + 'Ignoring all checkpoints and re-fetching all data from ' + 'CommCare.' 
+ ) elif since: logger.debug('Starting from %s', args.since) @@ -330,16 +510,22 @@ def main_with_args(args): 'get_location_ancestor': lp.get_location_ancestor } env = ( - BuiltInEnv(static_env) - | CommCareHqEnv(api_client, until=until, page_size=args.batch_size) - | JsonPathEnv({}) - | EmitterEnv(writer) + BuiltInEnv(static_env) + | CommCareHqEnv(api_client, until=until, page_size=args.batch_size) + | JsonPathEnv({}) + | EmitterEnv(writer) ) exit_status = evaluate_query(env, query) if args.output_format == 'json': - print(json.dumps(list(writer.tables.values()), indent=4, default=default_to_json)) + print( + json.dumps( + list(writer.tables.values()), + indent=4, + default=default_to_json + ) + ) return exit_status diff --git a/commcare_export/commcare_hq_client.py b/commcare_export/commcare_hq_client.py index 0761532b..84cb83d8 100644 --- a/commcare_export/commcare_hq_client.py +++ b/commcare_export/commcare_hq_client.py @@ -1,31 +1,41 @@ -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes +from __future__ import ( + absolute_import, + division, + generators, + nested_scopes, + print_function, + unicode_literals, +) import copy import logging +import sys from collections import OrderedDict +from math import ceil +from urllib.parse import urlencode import backoff import requests -from requests.auth import AuthBase -from requests.auth import HTTPDigestAuth +from requests.auth import AuthBase, HTTPDigestAuth + +import commcare_export +from commcare_export.repeatable_iterator import RepeatableIterator +from commcare_export import get_logger AUTH_MODE_PASSWORD = 'password' AUTH_MODE_APIKEY = 'apikey' -try: - from urllib.request import urlopen - from urllib.parse import urlparse, urlencode, parse_qs -except ImportError: - from urlparse import urlparse, parse_qs - from urllib import urlopen, urlencode -import commcare_export -from commcare_export.repeatable_iterator import RepeatableIterator +LATEST_KNOWN_VERSION = '0.5' +RESOURCE_REPEAT_LIMIT = 10 + +logger = get_logger(__file__) -logger = logging.getLogger(__name__) -LATEST_KNOWN_VERSION='0.5' -RESOURCE_REPEAT_LIMIT=10 +def on_wait(details): + time_to_wait = details["wait"] + logger.warning(f"Rate limit reached. Waiting for {time_to_wait} seconds.") + def on_backoff(details): _log_backoff(details, 'Waiting for retry.') @@ -37,7 +47,10 @@ def on_giveup(details): def _log_backoff(details, action_message): details['__suffix'] = action_message - logger.warning("Request failed after {tries} attempts ({elapsed:.1f}s). {__suffix}".format(**details)) + logger.warning( + "Request failed after {tries} attempts ({elapsed:.1f}s). {__suffix}" + .format(**details) + ) def is_client_error(ex): @@ -51,6 +64,7 @@ def is_client_error(ex): class ResourceRepeatException(Exception): + def __init__(self, message): self.message = message @@ -63,15 +77,23 @@ class CommCareHqClient(object): A connection to CommCareHQ for a particular version, project, and user. 
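[Editor's note, not part of the patch] For context, a minimal sketch of constructing the client defined here and fetching one small page of forms; the URL, project and credentials are placeholder values. Per the get() docstring further down, 429 responses are retried after the Retry-After interval and other request errors back off exponentially.

    from commcare_export.commcare_hq_client import (
        AUTH_MODE_APIKEY,
        CommCareHqClient,
    )

    client = CommCareHqClient(
        url='https://www.commcarehq.org',
        project='my-project',            # placeholder project space
        username='user@example.com',     # placeholder username
        password='my-api-key',           # the API key when using AUTH_MODE_APIKEY
        auth_mode=AUTH_MODE_APIKEY,
    )
    # `get` returns the decoded JSON body of the list endpoint.
    first_page = client.get('form', params={'limit': 10})
    print(first_page['meta']['total_count'])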
""" - def __init__(self, url, project, username, password, - auth_mode=AUTH_MODE_PASSWORD, version=LATEST_KNOWN_VERSION, checkpoint_manager=None): + def __init__( + self, + url, + project, + username, + password, + auth_mode=AUTH_MODE_PASSWORD, + version=LATEST_KNOWN_VERSION, + ): self.version = version self.url = url self.project = project self.__auth = self._get_auth(username, password, auth_mode) self.__session = None - def _get_auth(self, username, password, mode): + @staticmethod + def _get_auth(username, password, mode): if mode == AUTH_MODE_PASSWORD: return HTTPDigestAuth(username, password) elif mode == AUTH_MODE_APIKEY: @@ -84,7 +106,7 @@ def session(self): if self.__session == None: self.__session = requests.Session() self.__session.headers.update({ - 'User-Agent': 'commcare-export/%s' % commcare_export.__version__ + 'User-Agent': f'commcare-export/{commcare_export.__version__}' }) return self.__session @@ -97,35 +119,82 @@ def session(self, session): def api_url(self): return '%s/a/%s/api/v%s' % (self.url, self.project, self.version) - @backoff.on_exception( - backoff.expo, requests.exceptions.RequestException, - max_time=300, giveup=is_client_error, - on_backoff=on_backoff, on_giveup=on_giveup - ) + @staticmethod + def _should_raise_for_status(response): + return "Retry-After" not in response.headers + def get(self, resource, params=None): """ - Gets the named resource. + Gets the named resource. When the server returns a 429 (too many requests), the process will sleep for + the amount of seconds specified in the Retry-After header from the response, after which it will raise + an exception to trigger the retry action. - Currently a bit of a vulnerable stub that works - for this particular use case in the hands of a trusted user; would likely + Currently, a bit of a vulnerable stub that works for this + particular use case in the hands of a trusted user; would likely want this to work like (or via) slumber. """ - logger.debug("Fetching '%s' batch: %s", resource, params) - resource_url = '%s/%s/' % (self.api_url, resource) - response = self.session.get(resource_url, params=params, auth=self.__auth, timeout=60) - response.raise_for_status() + @backoff.on_predicate( + backoff.runtime, + predicate=lambda r: r.status_code == 429, + value=lambda r: ceil(float(r.headers.get("Retry-After", 1.0))), + jitter=None, + on_backoff=on_wait, + ) + @backoff.on_exception( + backoff.expo, + requests.exceptions.RequestException, + max_time=300, + giveup=is_client_error, + on_backoff=on_backoff, + on_giveup=on_giveup + ) + def _get(resource, params=None): + logger.debug("Fetching '%s' batch: %s", resource, params) + resource_url = f'{self.api_url}/{resource}/' + response = self.session.get( + resource_url, params=params, auth=self.__auth, timeout=60 + ) + if self._should_raise_for_status(response): + try: + response.raise_for_status() + except Exception as e: + # for non-verbose output, skip the stacktrace + if not logger.isEnabledFor(logging.DEBUG): + if isinstance(e, requests.exceptions.HTTPError) and response.status_code == 401: + logger.error( + f"#{e}. Please ensure that your CommCare HQ credentials are correct and auth-mode " + f"is passed as 'apikey' if using API Key to authenticate. Also, verify that your " + f"account has access to the project and the necessary permissions to use " + f"commcare-export." 
+ ) + else: + logger.error(str(e)) + sys.exit() + raise e + + return response + + response = _get(resource, params) return response.json() - - def iterate(self, resource, paginator, params=None, checkpoint_manager=None): + + def iterate( + self, + resource, + paginator, + params=None, + checkpoint_manager=None, + ): """ Assumes the endpoint is a list endpoint, and iterates over it making a lot of assumptions that it is like a tastypie endpoint. """ + unknown_count = 'unknown' params = dict(params or {}) + def iterate_resource(resource=resource, params=params): more_to_fetch = True last_batch_ids = set() - total_count = None + total_count = unknown_count fetched = 0 repeat_counter = 0 last_params = None @@ -136,54 +205,67 @@ def iterate_resource(resource=resource, params=params): else: repeat_counter = 0 if repeat_counter >= RESOURCE_REPEAT_LIMIT: - raise ResourceRepeatException("Requested resource '{}' {} times with same parameters".format(resource, repeat_counter)) + raise ResourceRepeatException( + f"Requested resource '{resource}' {repeat_counter} " + "times with same parameters" + ) batch = self.get(resource, params) last_params = copy.copy(params) - if not total_count or total_count == 'unknown' or fetched >= total_count: - total_count = int(batch['meta']['total_count']) if batch['meta']['total_count'] else 'unknown' + batch_meta = batch['meta'] + if total_count == unknown_count or fetched >= total_count: + if batch_meta.get('total_count'): + total_count = int(batch_meta['total_count']) + else: + total_count = unknown_count fetched = 0 - fetched += len(batch['objects']) + batch_objects = batch['objects'] + fetched += len(batch_objects) logger.debug('Received %s of %s', fetched, total_count) - - if not batch['objects']: + if not batch_objects: more_to_fetch = False else: - for obj in batch['objects']: + got_new_data = False + for obj in batch_objects: if obj['id'] not in last_batch_ids: yield obj + got_new_data = True - if batch['meta']['next']: - last_batch_ids = {obj['id'] for obj in batch['objects']} + if batch_meta.get('next'): + last_batch_ids = {obj['id'] for obj in batch_objects} params = paginator.next_page_params_from_batch(batch) if not params: more_to_fetch = False else: more_to_fetch = False - self.checkpoint(checkpoint_manager, paginator, batch, not more_to_fetch) - - return RepeatableIterator(iterate_resource) + limit = batch_meta.get('limit') + if more_to_fetch: + # Handle the case where API is 'non-counting' + # and repeats the last batch + repeated_last_page_of_non_counting_resource = ( + not got_new_data and total_count == unknown_count + and (limit and len(batch_objects) < limit) + ) + more_to_fetch = not repeated_last_page_of_non_counting_resource + + paginator.set_checkpoint( + checkpoint_manager, + batch, + not more_to_fetch + ) - def checkpoint(self, checkpoint_manager, paginator, batch, is_final): - from commcare_export.commcare_minilinq import DatePaginator - if isinstance(paginator, DatePaginator): - since_date = paginator.get_since_date(batch) - if since_date: - checkpoint_manager.set_checkpoint(since_date, is_final) - else: - logger.warning('Failed to get a checkpoint date from a batch of data.') + return RepeatableIterator(iterate_resource) class MockCommCareHqClient(object): """ - An in-memory mock of the hq client, instantiated - with a simple mapping of resource and params to results. + An in-memory mock of the hq client, instantiated with a simple + mapping of resource and params to results. 
- Since dictionaries are not hashable, the mapping is - written as a pair of tuples, handled appropriately - internallly. + Since dictionaries are not hashable, the mapping is written as a + pair of tuples, handled appropriately internally. MockCommCareHqClient({ 'forms': [ @@ -195,28 +277,50 @@ class MockCommCareHqClient(object): ), ] }) - """ - def __init__(self, mock_data): - self.mock_data = dict([(resource, dict([(urlencode(OrderedDict(sorted(params.items()))), result) for params, result in resource_results])) - for resource, resource_results in mock_data.items()]) + """ - def iterate(self, resource, paginator, params=None, checkpoint_manager=None): - logger.debug('Mock client call to resource "%s" with params "%s"', resource, params) - return self.mock_data[resource][urlencode(OrderedDict(sorted(params.items())))] + def __init__(self, mock_data): + self.mock_data = { + resource: { + _params_to_url(params): result + for (params, result) in resource_results + } for (resource, resource_results) in mock_data.items() + } + + def iterate( + self, resource, paginator, params=None, checkpoint_manager=None + ): + logger.debug( + 'Mock client call to resource "%s" with params "%s"', + resource, + params + ) + return self.mock_data[resource][_params_to_url(params)] def get(self, resource): logger.debug('Mock client call to get resource "%s"', resource) - objects = self.mock_data[resource][urlencode(OrderedDict([('get', True)]))] + objects = self.mock_data[resource][_params_to_url({'get': True})] if objects: - return {'meta': {'limit': len(objects), 'next': None, - 'offset': 0, 'previous': None, - 'total_count': len(objects)}, - 'objects': objects} + return { + 'meta': { + 'limit': len(objects), + 'next': None, + 'offset': 0, + 'previous': None, + 'total_count': len(objects) + }, + 'objects': objects + } else: return None +def _params_to_url(params): + return urlencode(OrderedDict(sorted(params.items()))) + + class ApiKeyAuth(AuthBase): + def __init__(self, username, apikey): self.username = username self.apikey = apikey @@ -234,6 +338,5 @@ def __ne__(self, other): return not self == other def __call__(self, r): - r.headers['Authorization'] = 'apikey %s:%s' % (self.username, self.apikey) + r.headers['Authorization'] = f'apikey {self.username}:{self.apikey}' return r - diff --git a/commcare_export/commcare_minilinq.py b/commcare_export/commcare_minilinq.py index 78e8660e..b575e7f7 100644 --- a/commcare_export/commcare_minilinq.py +++ b/commcare_export/commcare_minilinq.py @@ -5,33 +5,63 @@ API directly. 
""" import json - -from commcare_export.env import DictEnv, CannotBind, CannotReplace +from enum import Enum +from urllib.parse import parse_qs, urlparse from datetime import datetime +from dateutil.parser import ParserError, parse + +from commcare_export.env import CannotBind, CannotReplace, DictEnv from commcare_export.misc import unwrap +from commcare_export import get_logger + +logger = get_logger(__file__) + +SUPPORTED_RESOURCES = { + 'form', + 'case', + 'user', + 'location', + 'application', + 'web-user', + 'messaging-event', + 'ucr', +} + +DEFAULT_PAGE_SIZE = 1000 +DEFAULT_UCR_PAGE_SIZE = 10000 + -try: - from urllib.parse import urlparse, parse_qs -except ImportError: - from urlparse import urlparse, parse_qs +class PaginationMode(Enum): + date_indexed = "date_indexed" + date_modified = "date_modified" + cursor = "cursor" + + @classmethod + def supported_modes(cls): + return [ + cls.date_indexed, + cls.cursor, + ] class SimpleSinceParams(object): + def __init__(self, start, end): self.start_param = start self.end_param = end def __call__(self, since, until): - params = { - self.start_param: since.isoformat() - } + params = {} + if since: + params[self.start_param] = since.isoformat() if until: params[self.end_param] = until.isoformat() return params class FormFilterSinceParams(object): + def __call__(self, since, until): range_expression = {} if since: @@ -40,59 +70,78 @@ def __call__(self, since, until): if until: range_expression['lte'] = until.isoformat() - server_modified_missing = {"missing": { - "field": "server_modified_on", "null_value": True, "existence": True} + server_modified_missing = { + "missing": { + "field": "server_modified_on", + "null_value": True, + "existence": True + } } query = json.dumps({ 'filter': { - "or": [ - { - "and": [ - { - "not": server_modified_missing - }, - { - "range": { - "server_modified_on": range_expression - } - } - ] - }, - { - "and": [ - server_modified_missing, - { - "range": { - "received_on": range_expression - } - } - ] - } - ] - }}) + "or": [{ + "and": [{ + "not": server_modified_missing + }, { + "range": { + "server_modified_on": range_expression + } + }] + }, + { + "and": [ + server_modified_missing, { + "range": { + "received_on": range_expression + } + } + ] + }] + } + }) return {'_search': query} -resource_since_params = { - 'form': FormFilterSinceParams(), - 'case': SimpleSinceParams('server_date_modified_start', 'server_date_modified_end'), - 'user': None, - 'location': None, - 'application': None, - 'web-user': None, +DATE_PARAMS = { + 'indexed_on': + SimpleSinceParams('indexed_on_start', 'indexed_on_end'), + 'server_date_modified': + SimpleSinceParams( + 'server_date_modified_start', 'server_date_modified_end' + ), # used by messaging-events + 'date_last_activity': + SimpleSinceParams('date_last_activity.gte', 'date_last_activity.lt'), } -def get_paginator(resource, page_size=1000): +def get_paginator( + resource, + page_size=None, + pagination_mode=PaginationMode.date_indexed, +): return { - 'form': DatePaginator('form', ['server_modified_on','received_on'], page_size), - 'case': DatePaginator('case', 'server_date_modified', page_size), - 'user': SimplePaginator('user', page_size), - 'location': SimplePaginator('location', page_size), - 'application': SimplePaginator('application', page_size), - 'web-user': SimplePaginator('web-user', page_size), - }[resource] + PaginationMode.date_indexed: { + 'form': DatePaginator('indexed_on', page_size), + 'case': DatePaginator('indexed_on', page_size), + 'messaging-event': 
DatePaginator('date_last_activity', page_size), + }, + PaginationMode.date_modified: { + 'form': + DatePaginator( + ['server_modified_on', 'received_on'], + page_size, + params=FormFilterSinceParams(), + ), + 'case': + DatePaginator('server_date_modified', page_size), + 'messaging-event': + DatePaginator('date_last_activity', page_size), + }, + PaginationMode.cursor: { + 'ucr': UCRPaginator(page_size), + }, + }[pagination_mode].get(resource, SimplePaginator(page_size)) class CommCareHqEnv(DictEnv): @@ -100,26 +149,36 @@ class CommCareHqEnv(DictEnv): An environment providing primitives for pulling from the CommCareHq API. """ - - def __init__(self, commcare_hq_client, until=None, page_size=1000): + + def __init__(self, commcare_hq_client, page_size=None, until=None): self.commcare_hq_client = commcare_hq_client self.until = until self.page_size = page_size - super(CommCareHqEnv, self).__init__({ - 'api_data' : self.api_data - }) + super(CommCareHqEnv, self).__init__({'api_data': self.api_data}) @unwrap('checkpoint_manager') - def api_data(self, resource, checkpoint_manager, payload=None, include_referenced_items=None): - if resource not in resource_since_params: - raise ValueError('I do not know how to access the API resource "%s"' % resource) - - paginator = get_paginator(resource, self.page_size) + def api_data( + self, + resource, + checkpoint_manager, + payload=None, + include_referenced_items=None + ): + if resource not in SUPPORTED_RESOURCES: + raise ValueError('Unknown API resource "%s' % resource) + + paginator = get_paginator( + resource, self.page_size, checkpoint_manager.pagination_mode + ) paginator.init(payload, include_referenced_items, self.until) - initial_params = paginator.next_page_params_since(checkpoint_manager.since_param) + initial_params = paginator.next_page_params_since( + checkpoint_manager.since_param + ) return self.commcare_hq_client.iterate( - resource, paginator, - params=initial_params, checkpoint_manager=checkpoint_manager + resource, + paginator, + params=initial_params, + checkpoint_manager=checkpoint_manager ) def bind(self, name, value): @@ -133,9 +192,11 @@ class SimplePaginator(object): """ Paginate based on the 'next' URL provided in the API response. 
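[Editor's note, not part of the patch] A short sketch of how the mapping above selects a paginator: resources listed under the active PaginationMode get their specific paginator, and anything else falls back to SimplePaginator.

    from commcare_export.commcare_minilinq import PaginationMode, get_paginator

    form_paginator = get_paginator('form', pagination_mode=PaginationMode.date_indexed)
    ucr_paginator = get_paginator('ucr', pagination_mode=PaginationMode.cursor)
    user_paginator = get_paginator('user', pagination_mode=PaginationMode.date_indexed)

    print(type(form_paginator).__name__)  # DatePaginator, keyed on 'indexed_on'
    print(type(ucr_paginator).__name__)   # UCRPaginator
    print(type(user_paginator).__name__)  # SimplePaginator (fallback)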
""" - def __init__(self, resource, page_size=1000): - self.resource = resource + + def __init__(self, page_size=None, params=None): + page_size = page_size if page_size else 1000 self.page_size = page_size + self.params = params def init(self, payload=None, include_referenced_items=None, until=None): self.payload = dict(payload or {}) # Do not mutate passed-in dicts @@ -146,14 +207,14 @@ def next_page_params_since(self, since=None): params = self.payload params['limit'] = self.page_size - resource_date_params = resource_since_params[self.resource] - if (since or self.until) and resource_date_params: - params.update( - resource_date_params(since, self.until) - ) + if (since or self.until) and self.params: + params.update(self.params(since, self.until)) if self.include_referenced_items: - params.update([('%s__full' % referenced_item, 'true') for referenced_item in self.include_referenced_items]) + params.update([ + (f'{referenced_item}__full', 'true') + for referenced_item in self.include_referenced_items + ]) return params @@ -161,22 +222,33 @@ def next_page_params_from_batch(self, batch): if batch['meta']['next']: return parse_qs(urlparse(batch['meta']['next']).query) + def set_checkpoint(self, *args, **kwargs): + pass + class DatePaginator(SimplePaginator): """ - This paginator is designed to get around the issue of deep paging where the deeper the page the longer - the query takes. + This paginator is designed to get around the issue of deep paging + where the deeper the page the longer the query takes. - Paginate records according to a date in the record. The params for the next batch will include a filter - for the date of the last record in the previous batch. + Paginate records according to a date in the record. The params for + the next batch will include a filter for the date of the last record + in the previous batch. - This also adds an ordering parameter to ensure that the records are ordered by the date field in ascending order. + This also adds an ordering parameter to ensure that the records are + ordered by the date field in ascending order. - :param resource: The name of the resource being fetched: ``form`` or ``case``. :param since_field: The name of the date field to use for pagination. 
+ :param page_size: Number of results to request in each page """ - def __init__(self, resource, since_field, page_size=1000): - super(DatePaginator, self).__init__(resource, page_size) + + DEFAULT_PARAMS = object() + + def __init__(self, since_field, page_size=None, params=DEFAULT_PARAMS): + page_size = page_size if page_size else DEFAULT_PAGE_SIZE + params = DATE_PARAMS[ + since_field] if params is DatePaginator.DEFAULT_PARAMS else params + super(DatePaginator, self).__init__(page_size, params) self.since_field = since_field def next_page_params_since(self, since=None): @@ -205,8 +277,52 @@ def get_since_date(self, batch): since = last_obj.get(self.since_field) if since: - for fmt in ('%Y-%m-%dT%H:%M:%SZ', '%Y-%m-%dT%H:%M:%S.%fZ'): - try: - return datetime.strptime(since, fmt) - except ValueError: - pass + try: + return parse( + since, + # ignoretz since we assume utc, and use naive + # datetimes everywhere + ignoretz=True + ) + except ParserError: + return None + + def set_checkpoint(self, checkpoint_manager, batch, is_final): + since_date = self.get_since_date(batch) + if since_date: + try: + last_obj = batch['objects'][-1] + except IndexError: + last_obj = {} + checkpoint_manager.set_checkpoint( + since_date, is_final, doc_id=last_obj.get("id", None) + ) + else: + logger.warning( + 'Failed to get a checkpoint date from a batch of data.' + ) + + +class UCRPaginator(SimplePaginator): + def __init__(self, page_size=None, *args, **kwargs): + super().__init__(page_size, *args, **kwargs) + self.page_size = page_size if page_size else DEFAULT_UCR_PAGE_SIZE + + def next_page_params_from_batch(self, batch): + params = super(UCRPaginator, self).next_page_params_from_batch(batch) + if params: + return self.payload | params + + def next_page_params_since(self, since=None): + params = self.payload + params['cursor'] = since + params["limit"] = self.page_size + return params + + def set_checkpoint(self, checkpoint_manager, batch, is_final): + cursor = self.next_page_params_from_batch(batch)['cursor'][0] + checkpoint_manager.set_checkpoint( + datetime.utcnow(), + is_final, + cursor=cursor, + ) diff --git a/commcare_export/data_types.py b/commcare_export/data_types.py index 3a15340c..8ffe3c8f 100644 --- a/commcare_export/data_types.py +++ b/commcare_export/data_types.py @@ -5,14 +5,18 @@ DATA_TYPE_DATE = 'date' DATA_TYPE_DATETIME = 'datetime' DATA_TYPE_INTEGER = 'integer' +DATA_TYPE_JSON = 'json' DATA_TYPES_TO_SQLALCHEMY_TYPES = { + DATA_TYPE_TEXT: sqlalchemy.Text(), DATA_TYPE_BOOLEAN: sqlalchemy.Boolean(), DATA_TYPE_DATETIME: sqlalchemy.DateTime(), DATA_TYPE_DATE: sqlalchemy.Date(), DATA_TYPE_INTEGER: sqlalchemy.Integer(), + DATA_TYPE_JSON: sqlalchemy.JSON(), } + class UnknownDataType(Exception): pass diff --git a/commcare_export/env.py b/commcare_export/env.py index 4f8d675c..2b2a5223 100644 --- a/commcare_export/env.py +++ b/commcare_export/env.py @@ -1,81 +1,84 @@ -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes - import hashlib import json -from datetime import datetime +import logging import operator +import sys +import uuid +from typing import Any, Dict, Union, overload + import pytz -import six -from itertools import chain -from jsonpath_rw import jsonpath -from jsonpath_rw.parser import parse as parse_jsonpath +from commcare_export.jsonpath_utils import split_leftmost from commcare_export.misc import unwrap, unwrap_val - from commcare_export.repeatable_iterator import RepeatableIterator +from jsonpath_ng import jsonpath +from 
jsonpath_ng.parser import parse as parse_jsonpath + +logger = logging.getLogger(__name__) JSONPATH_CACHE = {} -class CannotBind(Exception): pass -class CannotReplace(Exception): pass -class CannotEmit(Exception): pass -class NotFound(Exception): pass + +class CannotBind(Exception): + pass + + +class CannotReplace(Exception): + pass + + +class CannotEmit(Exception): + pass + + +class NotFound(Exception): + pass + class Env(object): """ An abstract model of an "environment" where data can be bound to names and later looked up. Not simply a dictionary as lookup in our case may support JsonPath, or may be a chaining of other - environments, so the abstract interface will - allow experimentation and customization. + environments, so the abstract interface will allow experimentation + and customization. """ # # Interface # - def bind(self, name, value): + def bind(self, name: str, value: Any) -> 'Env': """ - (key, ??) -> Env - - Returns a new environment that is equivalent - to the current except the provided key is - bound to the value passed in. If the environment - does not support such a binding, raises - CannotBind + Returns a new environment that is equivalent to the current + except the provided key is bound to the value passed in. If the + environment does not support such a binding, raises CannotBind """ raise NotImplementedError() - def lookup(self, key): + def lookup(self, key: str) -> Any: """ - key -> ?? - - Note that the ?? may be None which may mean - the value was unbound or may mean it was - found and was None. This may need revisiting. - This may also raise NotFound if it is the + Note that the return value may be ``None`` which may mean the + value was unbound or may mean it was found and was None. This + may need revisiting. This may also raise NotFound if it is the sort of environment that does that. """ raise NotImplementedError() - def replace(self, data): + def replace(self, data: Dict[str, Any]) -> 'Env': """ - data -> Env - - Completely replace the environment with new - data (somewhat like "this"-based Map functions a la jQuery). - Could be the same as creating a new empty env - and binding "@" in JsonPath. + Completely replace the environment with new data (somewhat like + "this"-based Map functions a la jQuery). Could be the same as + creating a new empty env and binding "@" in JsonPath. - May raise CannotReplace if this environment does - not support the input replacement + May raise CannotReplace if this environment does not support the + input replacement """ raise NotImplementedError() - # Minor impurity of the idea of a binding env: - # also allow `Emit` to directly call into - # the environment. It is up to the env - # whether to store it, write it immediately, - # or do something clever with iterators, etc. + # Minor impurity of the idea of a binding env: also allow `Emit` to + # directly call into the environment. It is up to the env whether to + # store it, write it immediately, or do something clever with + # iterators, etc. def emit_table(self, table_spec): raise CannotEmit() @@ -87,50 +90,62 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): pass - + # # Fluent interface to combinators # def __or__(self, other): return OrElse(self, other) + # # Combinators # + class OrElse(Env): """ - An environment that chains together a left environment - and a right environment. 
Note that this differes from - just a bunch of bindings, as the two envs might have - entirely different mechanisms (for example a magic - environment for special operators vs a JsonPathEnv - that always returns a list and operates only on - simple data) + An environment that chains together a left environment and a right + environment. Note that this differs from just a bunch of bindings, + as the two envs might have entirely different mechanisms (for + example a magic environment for special operators vs a JsonPathEnv + that always returns a list and operates only on simple data) """ + def __init__(self, left, right): self.left = left self.right = right - + def bind(self, name, value): - try: return OrElse(self.left.bind(name, value), self.right) - except CannotBind: return OrElse(self.left, self.right.bind(name, value)) + try: + return OrElse(self.left.bind(name, value), self.right) + except CannotBind: + return OrElse(self.left, self.right.bind(name, value)) def lookup(self, name): - try: return self.left.lookup(name) - except NotFound: return self.right.lookup(name) + try: + return self.left.lookup(name) + except NotFound: + return self.right.lookup(name) def replace(self, data): # A bit sketchy... - try: return OrElse(self.left.replace(data), self.right) - except CannotReplace: return OrElse(self.left, self.right.replace(data)) + try: + return OrElse(self.left.replace(data), self.right) + except CannotReplace: + return OrElse(self.left, self.right.replace(data)) def emit_table(self, table_spec): - try: return self.left.emit_table(table_spec) - except CannotEmit: return self.right.emit_table(table_spec) + try: + return self.left.emit_table(table_spec) + except CannotEmit: + return self.right.emit_table(table_spec) def has_emitted_tables(self): - return any([self.left.has_emitted_tables(), self.right.has_emitted_tables()]) + return any([ + self.left.has_emitted_tables(), + self.right.has_emitted_tables() + ]) def __enter__(self): self.left.__enter__() @@ -145,41 +160,49 @@ def __exit__(self, exc_type, exc_val, exc_tb): # # Concrete environment classes -# +# + class DictEnv(Env): """ A simple dictionary environment; more-or-less boring! """ + def __init__(self, d=None): self.d = d or {} def bind(self, name, value): return DictEnv(dict(list(self.d.items()) + [(name, value)])) - + def lookup(self, name): - try: return self.d[name] - except KeyError: raise NotFound(unwrap_val(name)) + try: + return self.d[name] + except KeyError: + raise NotFound(unwrap_val(name)) def replace(self, data): - if isinstance(data, dict): return DictEnv(data) - else: raise CannotReplace() + if isinstance(data, dict): + return DictEnv(data) + else: + raise CannotReplace() class JsonPathEnv(Env): """ - An environment like those that map names - to variables, but supporting dereferencing - an JsonPath expression. Note that it never - fails a lookup, but always returns an empty - list. + An environment like those that map names to variables, but + supporting dereferencing an JsonPath expression. Note that it never + fails a lookup, but always returns an empty list. 
It also interns all parsed expressions """ + def __init__(self, bindings=None): self.__bindings = bindings or {} - - # Currently hardcoded because it is a global is jsonpath-rw + self.__restrict_to_root = bool( + jsonpath.Fields("__root_only").find(self.__bindings) + ) + + # Currently hardcoded because it is a global is jsonpath-ng # Probably not widely used, but will require refactor if so jsonpath.auto_id_field = "id" @@ -187,34 +210,50 @@ def parse(self, jsonpath_string): if jsonpath_string not in JSONPATH_CACHE: JSONPATH_CACHE[jsonpath_string] = parse_jsonpath(jsonpath_string) return JSONPATH_CACHE[jsonpath_string] - - def lookup(self, name): - "str|JsonPath -> ??" - if isinstance(name, six.string_types): + + def lookup( + self, + name: Union[str, jsonpath.JSONPath] + ) -> RepeatableIterator: + if isinstance(name, str): jsonpath_expr = self.parse(name) elif isinstance(name, jsonpath.JSONPath): jsonpath_expr = name else: raise NotFound(unwrap_val(name)) - def iter(jsonpath_expr=jsonpath_expr): # Capture closure + # special case for 'id' + if self.__restrict_to_root and str(jsonpath_expr) != 'id': + expr, _ = split_leftmost(jsonpath_expr) + if not isinstance(expr, jsonpath.Root): + return RepeatableIterator(lambda: iter(())) + + def iterator(jsonpath_expr=jsonpath_expr): # Capture closure for datum in jsonpath_expr.find(self.__bindings): - # HACK: The auto id from jsonpath_rw is good, but we lose it when we do .value here, - # so just slap it on if not present + # HACK: The auto id from jsonpath_ng is good, but we + # lose it when we do .value here, so just slap it on if + # not present if isinstance(datum.value, dict) and 'id' not in datum.value: datum.value['id'] = jsonpath.AutoIdForDatum(datum).value yield datum - return RepeatableIterator(iter) + + return RepeatableIterator(iterator) + + @overload + def bind(self, key: str, value: Any, *args) -> Env: + ... + + @overload + def bind(self, bindings: Dict[str, Any], *args) -> Env: + ... def bind(self, *args): - "(str, ??) 
-> Env | ({str: ??}) -> Env" - new_bindings = dict(self.__bindings) if isinstance(args[0], dict): new_bindings.update(args[0]) return self.__class__(new_bindings) - - elif isinstance(args[0], six.string_types): + + elif isinstance(args[0], str): new_bindings[args[0]] = args[1] return self.__class__(new_bindings) @@ -237,8 +276,8 @@ def _not_val(val): def _to_unicode(val): if isinstance(val, bytes): return val.decode('utf8') - elif not isinstance(val, six.text_type): - return six.text_type(val) + elif not isinstance(val, str): + return str(val) return val @@ -290,6 +329,7 @@ def str2date(val): return date.replace(microsecond=0, tzinfo=None) + @unwrap('val') def bool2int(val): return int(str2bool(val)) @@ -301,7 +341,7 @@ def sha1(val): return None if not isinstance(val, bytes): - val = six.text_type(val).encode('utf8') + val = str(val).encode('utf8') return hashlib.sha1(val).hexdigest() @@ -347,7 +387,7 @@ def count_selected(val): @unwrap('val') def json2str(val): - if isinstance(val, six.string_types): + if isinstance(val, str): return val try: return json.dumps(val) @@ -355,10 +395,30 @@ def json2str(val): return +@unwrap('val') +def format_uuid(val): + """ + Renders a hex UUID in hyphen-separated groups + + >>> format_uuid('00a3e0194ce1458794c50971dee2de22') + '00a3e019-4ce1-4587-94c5-0971dee2de22' + >>> format_uuid(0x00a3e0194ce1458794c50971dee2de22) + '00a3e019-4ce1-4587-94c5-0971dee2de22' + """ + if not val: + return None + if isinstance(val, int): + val = hex(val) + try: + return str(uuid.UUID(val)) + except ValueError: + return None + + def join(*args): - args = [unwrap_val(arg)for arg in args] + args_ = [unwrap_val(arg) for arg in args] try: - return args[0].join(args[1:]) + return args_[0].join(args_[1:]) except TypeError: return '""' @@ -385,13 +445,48 @@ def attachment_url(val): ) +@unwrap('val') +def form_url(val): + return _doc_url('form_data') + + +@unwrap('val') +def case_url(val): + return _doc_url('case_data') + + +def _doc_url(url_path): + from commcare_export.minilinq import Apply, Reference, Literal + return Apply( + Reference('template'), + Literal('{}/a/{}/reports/' + url_path + '/{}/'), + Reference('commcarehq_base_url'), + Reference('$.domain'), + Reference('$.id'), + ) + + def template(format_template, *args): - args = [unwrap_val(arg) for arg in args] - return format_template.format(*args) + args_ = [unwrap_val(arg) for arg in args] + return format_template.format(*args_) def _or(*args): - unwrapped_args = (unwrap_val(arg) for arg in args) + return _or_impl(unwrap_val, *args) + + +def _or_raw(*args): + + def unwrap_iter(arg): + if isinstance(arg, RepeatableIterator): + return list(arg) + return arg + + return _or_impl(unwrap_iter, *args) + + +def _or_impl(_unwrap, *args): + unwrapped_args = (_unwrap(arg) for arg in args) vals = (val for val in unwrapped_args if val is not None and val != []) try: return next(vals) @@ -412,16 +507,23 @@ def substr(val, start, end): return val[start:end] +@unwrap('val') +def unique(val): + if isinstance(val, list): + return list(set(val)) + return val + + class BuiltInEnv(DictEnv): """ - A built-in environment of operators and functions - which does not support replacement or bindings. + A built-in environment of operators and functions which does not + support replacement or bindings. - For convenience, this environment has been chosen to - queue up tables to be written out, since it will be - the first env involved in almost any situation. 
+ For convenience, this environment has been chosen to queue up tables + to be written out, since it will be the first env involved in almost + any situation. """ - + def __init__(self, d=None): self.__tables = [] d = d or {} @@ -442,25 +544,34 @@ def __init__(self, d=None): 'str2num': str2num, 'str2date': str2date, 'json2str': json2str, + 'format-uuid': format_uuid, 'selected': selected, 'selected-at': selected_at, 'count-selected': count_selected, 'join': join, 'default': default, 'template': template, + 'form_url': form_url, + 'case_url': case_url, 'attachment_url': attachment_url, 'filter_empty': _not_val, 'or': _or, 'sha1': sha1, 'substr': substr, + '_or_raw': _or_raw, # for internal use, + 'unique': unique }) - return super(BuiltInEnv, self).__init__(d) + super(BuiltInEnv, self).__init__(d) - def bind(self, name, value): raise CannotBind() - def replace(self, data): raise CannotReplace() + def bind(self, name, value): + raise CannotBind() + + def replace(self, data): + raise CannotReplace() class EmitterEnv(Env): + def __init__(self, writer): self.writer = writer self.emitted = False @@ -471,30 +582,56 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.writer.__exit__(exc_type, exc_val, exc_tb) - def bind(self, name, value): raise CannotBind() - def replace(self, data): raise CannotReplace() - def lookup(self, key): raise NotFound() + def bind(self, name, value): + raise CannotBind() + + def replace(self, data): + raise CannotReplace() + + def lookup(self, key): + raise NotFound() def emit_table(self, table_spec): self.emitted = True table_spec.rows = self._unwrap_row_vals(table_spec.rows) - self.writer.write_table(table_spec) + try: + self.writer.write_table(table_spec) + except Exception as err: + if ( + not logger.isEnabledFor(logging.DEBUG) # not --verbose + and 'Row size too large' in str(err) + ): + logging.error( + 'Row size too large. The amount of data required by rows ' + 'is more than this type of database table allows. One ' + 'way to resolve this error is to reduce the number of ' + 'columns that you are exporting. A general guideline is ' + 'not to exceed 200 columns.' + ) + sys.exit(1) + else: + raise def has_emitted_tables(self): return self.emitted @staticmethod def _unwrap_row_vals(rows): - """The XMLtoJSON conversion in CommCare can result in a field being a JSON object - instead of a simple field (if the XML tag has attributes or different namespace from - the default). In this case the actual value of the XML element is stored in a '#text' field. """ + The XMLtoJSON conversion in CommCare can result in a field being + a JSON object instead of a simple field (if the XML tag has + attributes or different namespace from the default). In this + case the actual value of the XML element is stored in a '#text' + field. 
+ """ + def _unwrap_val(val): if isinstance(val, dict): if '#text' in val: return val.get('#text') elif all(key == 'id' or key.startswith('@') for key in val): - # this implies the XML element was empty since all keys are from attributes + # this implies the XML element was empty since all + # keys are from attributes return '' return val diff --git a/commcare_export/excel_query.py b/commcare_export/excel_query.py index e2279033..12f6d980 100644 --- a/commcare_export/excel_query.py +++ b/commcare_export/excel_query.py @@ -1,16 +1,25 @@ -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes -import re -from collections import defaultdict, namedtuple - -from jsonpath_rw.lexer import JsonPathLexerError -from six.moves import xrange +from __future__ import ( + absolute_import, + division, + generators, + nested_scopes, + print_function, + unicode_literals, +) -from jsonpath_rw import jsonpath -from jsonpath_rw.parser import parse as parse_jsonpath +from collections import defaultdict, namedtuple -from commcare_export.exceptions import LongFieldsException, MissingColumnException, ReservedTableNameException +from commcare_export.exceptions import ( + LongFieldsException, + MissingColumnException, + ReservedTableNameException, +) +from commcare_export.jsonpath_utils import split_leftmost from commcare_export.map_format import compile_map_format_via from commcare_export.minilinq import * +from jsonpath_ng import jsonpath +from jsonpath_ng.parser import parse as parse_jsonpath + def take_while(pred, iterator): for v in iterator: @@ -19,6 +28,7 @@ def take_while(pred, iterator): else: return + def drop_while(pred, iterator): for v in iterator: if not pred(v): @@ -28,66 +38,107 @@ def drop_while(pred, iterator): for v in iterator: yield v + def without_empty_tail(cells): """ Returns the prefix of a column that is not entirely empty. 
""" - return list(reversed(list(drop_while(lambda v: (not v) or (not v.value), reversed(cells))))) + return list( + reversed( + list( + drop_while( + lambda v: (not v) or (not v.value), reversed(cells) + ) + ) + ) + ) + def map_value(mappings_sheet, mapping_name, source_value): - "From the mappings_sheet, replaces the source_value with appropriate output value" + """ + From the mappings_sheet, replaces the source_value with appropriate + output value + """ return source_value def get_column_by_name(worksheet, column_name): # columns and rows are indexed from 1 - for col in xrange(1, worksheet.max_column + 1): + for col in range(1, worksheet.max_column + 1): value = worksheet.cell(row=1, column=col).value - value = value.lower() if value else value + value = value.lower().strip() if value else value if column_name == value: return without_empty_tail([ - worksheet.cell(row=i, column=col) for i in xrange(2, worksheet.max_row + 1) + worksheet.cell(row=i, column=col) + for i in range(2, worksheet.max_row + 1) ]) def get_columns_by_prefix(worksheet, column_prefix): # columns and rows are indexed from 1 - for col in xrange(1, worksheet.max_column + 1): + for col in range(1, worksheet.max_column + 1): value = worksheet.cell(row=1, column=col).value if value and value.lower().startswith(column_prefix): yield value, without_empty_tail([ - worksheet.cell(row=i, column=col) for i in xrange(2, worksheet.max_row + 1) + worksheet.cell(row=i, column=col) + for i in range(2, worksheet.max_row + 1) ]) def compile_mappings(worksheet): mapping_names = get_column_by_name(worksheet, "mapping name") - sources = extended_to_len(len(mapping_names), get_column_by_name(worksheet, "source")) - destinations = extended_to_len(len(mapping_names), get_column_by_name(worksheet, "destination")) + sources = extended_to_len( + len(mapping_names), get_column_by_name(worksheet, "source") + ) + destinations = extended_to_len( + len(mapping_names), get_column_by_name(worksheet, "destination") + ) mappings = defaultdict(lambda: defaultdict(lambda: None)) - - for mapping_name, source, dest in zip(mapping_names, sources, destinations): + + for mapping_name, source, dest in zip( + mapping_names, sources, destinations + ): if mapping_name and source: - mappings[mapping_name.value][source.value] = dest.value if dest else None + mappings[mapping_name.value][source.value + ] = dest.value if dest else None return mappings + def compile_filters(worksheet, mappings=None): - filter_names = [cell.value for cell in get_column_by_name(worksheet, 'filter name') or []] + filter_names = [ + cell.value + for cell in get_column_by_name(worksheet, 'filter name') or [] + ] if not filter_names: return [] - filter_values = extended_to_len(len(filter_names), [cell.value for cell in get_column_by_name(worksheet, 'filter value') or []]) - return zip(filter_names, filter_values) + filter_values = extended_to_len( + len(filter_names), + [ + cell.value + for cell in get_column_by_name(worksheet, 'filter value') or [] + ] + ) + # Preserve values of duplicate filter names. Results in an OR filter. + # e.g. 
{'type': ['person'], 'owner_id': ['abc123', 'def456']} + filters = defaultdict(list) + for k, v in zip(filter_names, filter_values): + filters[k].append(v) + return filters + def extended_to_len(desired_len, some_list, value=None): - return [some_list[i] if i < len(some_list) else value - for i in xrange(0, desired_len)] + return [ + some_list[i] if i < len(some_list) else value + for i in range(0, desired_len) + ] def _get_safe_source_field(source_field): + def _safe_node(node): try: parse_jsonpath(node) @@ -110,11 +161,22 @@ def _safe_node(node): return Reference(source_field) -def compile_field(field, source_field, alternate_source_fields=None, map_via=None, format_via=None, mappings=None): +def compile_field( + field, + source_field, + alternate_source_fields=None, + map_via=None, + format_via=None, + mappings=None +): expr = _get_safe_source_field(source_field) if alternate_source_fields: - expr = Apply(Reference('or'), expr, *[Reference(alt_field) for alt_field in alternate_source_fields]) + expr = Apply( + Reference('or'), + expr, + *[Reference(alt_field) for alt_field in alternate_source_fields] + ) if map_via: expr = compile_map_format_via(expr, map_via) @@ -129,15 +191,28 @@ def compile_field(field, source_field, alternate_source_fields=None, map_via=Non def compile_mapped_field(field_mappings, field_expression): # quote the ref in case it has special chars - quoted_field = Apply(Reference('join'), Literal(''), Literal('"'), field_expression, Literal('"')) + quoted_field = Apply( + Reference('join'), + Literal(''), + Literal('"'), + field_expression, + Literal('"') + ) # produce the mapping reference i.e. 'mapping."X"' - mapping_ref = Apply(Reference('join'), Literal('.'), Literal('mapping'), quoted_field) + mapping_ref = Apply( + Reference('join'), Literal('.'), Literal('mapping'), quoted_field + ) # apply the reference to the field mappings to get the final value - mapped_value = FlatMap(source=Literal([field_mappings]), body=Reference(mapping_ref), name='mapping') + mapped_value = FlatMap( + source=Literal([field_mappings]), + body=Reference(mapping_ref), + name='mapping' + ) return Apply(Reference('default'), mapped_value, field_expression) def _get_alternate_source_fields_from_csv(worksheet, num_fields): + def _clean_csv_field(field): if field and field.value: return [val.strip() for val in field.value.split(',')] @@ -150,10 +225,14 @@ def _clean_csv_field(field): def _get_alternate_source_fields_from_columns(worksheet, num_fields): - matching_columns = sorted(get_columns_by_prefix(worksheet, 'alternate source field'), key=lambda x: x[0]) + matching_columns = sorted( + get_columns_by_prefix(worksheet, 'alternate source field'), + key=lambda x: x[0] + ) alt_source_cols = [ - extended_to_len(num_fields, [cell.value if cell else cell for cell in alt_col]) - for col_name, alt_col in matching_columns + extended_to_len( + num_fields, [cell.value if cell else cell for cell in alt_col] + ) for (col_name, alt_col) in matching_columns ] # transpose columns to rows alt_srouce_fields = map(list, zip(*alt_source_cols)) @@ -174,13 +253,26 @@ def compile_fields(worksheet, mappings=None): if not fields: return [] - source_fields = extended_to_len(len(fields), get_column_by_name(worksheet, 'source field') or []) - map_vias = extended_to_len(len(fields), get_column_by_name(worksheet, 'map via') or []) - format_vias = extended_to_len(len(fields), get_column_by_name(worksheet, 'format via') or []) + source_fields = extended_to_len( + len(fields), + get_column_by_name(worksheet, 'source field') 
or [] + ) + map_vias = extended_to_len( + len(fields), + get_column_by_name(worksheet, 'map via') or [] + ) + format_vias = extended_to_len( + len(fields), + get_column_by_name(worksheet, 'format via') or [] + ) - alternate_source_fields = get_alternate_source_fields(worksheet, len(fields)) + alternate_source_fields = get_alternate_source_fields( + worksheet, len(fields) + ) - args = zip(fields, source_fields, alternate_source_fields, map_vias, format_vias) + args = zip( + fields, source_fields, alternate_source_fields, map_vias, format_vias + ) return [ compile_field( field=field.value, @@ -190,39 +282,36 @@ def compile_fields(worksheet, mappings=None): format_via=format_via.value if format_via else None, mappings=mappings ) - for field, source_field, alt_source_fields, map_via, format_via in args + for (field, source_field, alt_source_fields, map_via, format_via) + in args ] -def split_leftmost(jsonpath_expr): - if isinstance(jsonpath_expr, jsonpath.Child): - further_leftmost, rest = split_leftmost(jsonpath_expr.left) - return further_leftmost, rest.child(jsonpath_expr.right) - elif isinstance(jsonpath_expr, jsonpath.Descendants): - further_leftmost, rest = split_leftmost(jsonpath_expr.left) - return further_leftmost, jsonpath.Descendants(rest, jsonpath_expr.right) - else: - return (jsonpath_expr, jsonpath.This()) - -def compile_source(worksheet): +def compile_source(worksheet, value_or_root=False): """ - Compiles just the part of the Excel Spreadsheet that - indicates the API endpoint to hit along with optional filters - and an optional JSONPath within that endpoint, + Compiles just the part of the Excel Spreadsheet that indicates the + API endpoint to hit along with optional filters and an optional + JSONPath within that endpoint, For example, this spreadsheet - + Data Source Filter Name Filter Value Include Referenced Items ----------------------------- ------------ ------------------ -------------------------- form[*].form.child_questions app_id cases xmlns.exact - Should fetch from api/form?app_id=&xmlns.exact=&cases__full=true - and then iterate (FlatMap) over all child questions. + Should fetch from api/form?app_id=&xmlns.exact=&cases__full=true and then iterate (FlatMap) over all child + questions. + + :return: tuple of the 'data source' expression and the 'root doc + expression'. + + 'data source': The MiniLinq that calls 'api_data' function to + get data from CommCare - :return: tuple of the 'data source' expression and the 'root doc expression'. - 'data source': The MiniLinq that calls 'api_data' function to get data from CommCare - 'root doc expression': The MiniLinq that is applied to each doc, can be None. + 'root doc expression': The MiniLinq that is applied to each doc, + can be None. 
""" data_source_column = get_column_by_name(worksheet, 'data source') @@ -230,13 +319,22 @@ def compile_source(worksheet): raise Exception('Sheet has no "Data Source" column.') data_source_str = data_source_column[0].value filters = compile_filters(worksheet) - include_referenced_items = [cell.value for cell in (get_column_by_name(worksheet, 'include referenced items') or [])] + include_referenced_items = [ + cell.value for cell in + (get_column_by_name(worksheet, 'include referenced items') or []) + ] - data_source, data_source_jsonpath = split_leftmost(parse_jsonpath(data_source_str)) - maybe_redundant_slice, remaining_jsonpath = split_leftmost(data_source_jsonpath) + data_source, data_source_jsonpath = split_leftmost( + parse_jsonpath(data_source_str) + ) + maybe_redundant_slice, remaining_jsonpath = split_leftmost( + data_source_jsonpath + ) - # The leftmost _must_ be of type Fields with one field and will pull out the first field - if not isinstance(data_source, jsonpath.Fields) or len(data_source.fields) > 1: + # The leftmost _must_ be of type Fields with one field and will pull + # out the first field + if not isinstance(data_source, + jsonpath.Fields) or len(data_source.fields) > 1: raise Exception('Bad value for data source: %s' % str(data_source)) data_source = data_source.fields[0] @@ -244,29 +342,69 @@ def compile_source(worksheet): if isinstance(maybe_redundant_slice, jsonpath.Slice): data_source_jsonpath = remaining_jsonpath - api_query_args = [Reference("api_data"), Literal(data_source), Reference('checkpoint_manager')] - + api_query_args = [ + Reference("api_data"), + Literal(data_source), + Reference('checkpoint_manager') + ] + if not filters: if include_referenced_items: - api_query_args.append(Literal(None)) # Pad the argument list if we have further args; keeps tests and user code more readable at the expense of this conditional + # Pad the argument list if we have further args; keeps tests + # and user code more readable at the expense of this + # conditional + api_query_args.append(Literal(None)) else: - api_query_args.append(Literal(dict(filters))) + api_query_args.append(Literal(filters)) if include_referenced_items: api_query_args.append(Literal(include_referenced_items)) api_query = Apply(*api_query_args) - if data_source_jsonpath is None or isinstance(data_source_jsonpath, jsonpath.This) or isinstance(data_source_jsonpath, jsonpath.Root): + if ( + data_source_jsonpath is None + or isinstance(data_source_jsonpath, jsonpath.This) + or isinstance(data_source_jsonpath, jsonpath.Root) + ): return data_source, api_query, None else: - return data_source, api_query, Reference(str(data_source_jsonpath)) + if value_or_root: + # if the jsonpath doesn't yield a value yield the root document + expr = get_value_or_root_expression(data_source_jsonpath) + else: + expr = Reference(str(data_source_jsonpath)) + return data_source, api_query, expr + + +def get_value_or_root_expression(value_expression): + """ + Return expression used when iterating over a nested document but + also wanting a record if the value expression returns an empty + result. + """ + + # We add a bind here so that in JsonPathEnv we can restrict + # expressions to only those that reference the root. That prevents + # us from mistakenly getting values from the root that happen to + # have the same name as those in the child. 
+ root_expr = Bind("__root_only", Literal(True), Reference("$")) + return Apply( + Reference('_or_raw'), Reference(str(value_expression)), root_expr + ) + # If the source is expected to provide a column, then require that it is # already present or can be added without conflicting with an existing # column. -def require_column_in_sheet(sheet_name, data_source, table_name, output_headings, - output_fields, column_enforcer): +def require_column_in_sheet( + sheet_name, + data_source, + table_name, + output_headings, + output_fields, + column_enforcer, +): # Check for conflicting use of column name. extend_fields = True @@ -281,24 +419,42 @@ def require_column_in_sheet(sheet_name, data_source, table_name, output_headings extend_fields = False continue else: - raise Exception('Field name "{}" conflicts with an internal name.'.format(required_column.name.v)) + raise Exception( + 'Field name "{}" conflicts with an internal name.' + .format(required_column.name.v) + ) if extend_fields: - headings = [Literal(output_heading.value) - for output_heading in output_headings] + [required_column.name] - body = List(output_fields + - [compile_field(field=required_column.name, - source_field=required_column.source)]) + headings = [ + Literal(output_heading.value) for output_heading in output_headings + ] + [required_column.name] + body = List( + output_fields + [ + compile_field( + field=required_column.name, + source_field=required_column.source + ) + ] + ) else: - headings = [Literal(output_heading.value) - for output_heading in output_headings] + headings = [ + Literal(output_heading.value) for output_heading in output_headings + ] body = List(output_fields) return (headings, body) -def parse_sheet(worksheet, mappings=None, column_enforcer=None): + +def parse_sheet( + worksheet, + mappings=None, + column_enforcer=None, + value_or_root=False, +): mappings = mappings or {} - data_source, source_expr, root_doc_expr = compile_source(worksheet) + data_source, source_expr, root_doc_expr = compile_source( + worksheet, value_or_root + ) table_name_column = get_column_by_name(worksheet, 'table name') if table_name_column: @@ -315,20 +471,25 @@ def parse_sheet(worksheet, mappings=None, column_enforcer=None): source = source_expr body = None else: - # note: if we want to add data types to the columns added by the column_enforcer - # this will have to conditionally move into the if/else below + # note: if we want to add data types to the columns added by the + # column_enforcer this will have to conditionally move into the + # if/else below data_types = [Literal(data_type.value) for data_type in output_types] if column_enforcer is not None: - (headings, body) = require_column_in_sheet(worksheet.title, - data_source, - output_table_name, - output_headings, - output_fields, - column_enforcer) + (headings, body) = require_column_in_sheet( + worksheet.title, + data_source, + output_table_name, + output_headings, + output_fields, + column_enforcer + ) source = source_expr else: - headings = [Literal(output_heading.value) - for output_heading in output_headings] + headings = [ + Literal(output_heading.value) + for output_heading in output_headings + ] source = source_expr body = List(output_fields) @@ -339,27 +500,48 @@ def parse_sheet(worksheet, mappings=None, column_enforcer=None): body, root_doc_expr, data_types, + data_source, ) -class SheetParts(namedtuple('SheetParts', 'name headings source body root_expr data_types')): - def __new__(cls, name, headings, source, body, root_expr=None, data_types=None): - data_types = 
data_types or [] - return super(SheetParts, cls).__new__(cls, name, headings, source, body, root_expr, data_types) +class SheetParts(namedtuple( + 'SheetParts', + 'name headings source body root_expr data_types data_source' +)): + + def __new__( + cls, + name, + headings, + source, + body, + root_expr=None, + data_types=None, + data_source=None + ): + return super().__new__( + cls, + name, + headings, + source, + body, + root_expr, + data_types or [], + data_source + ) @property def columns(self): - return [ - col.v for col in self.headings - ] + return [col.v for col in self.headings] -def parse_workbook(workbook, column_enforcer=None): +def parse_workbook(workbook, column_enforcer=None, value_or_root=False): """ Returns a MiniLinq corresponding to the Excel configuration, which consists of the following sheets: - 1. "Mappings" a sheet with three columns that defines simple lookup table functions + 1. "Mappings" a sheet with three columns that defines simple lookup + table functions: A. MappingName - the name by which this mapping is referenced B. Source - the value to match C. Destination - the value to return @@ -372,14 +554,23 @@ def parse_workbook(workbook, column_enforcer=None): mappings_sheet = None mappings = compile_mappings(mappings_sheet) if mappings_sheet else None - emit_sheets = [sheet_name for sheet_name in workbook.sheetnames if sheet_name != 'Mappings'] + emit_sheets = [ + sheet_name for sheet_name in workbook.sheetnames + if sheet_name != 'Mappings' + ] parsed_sheets = [] for sheet in emit_sheets: try: - sheet_parts = parse_sheet(workbook[sheet], mappings, column_enforcer) + sheet_parts = parse_sheet( + workbook[sheet], mappings, column_enforcer, value_or_root + ) except Exception as e: - logger.warning('Ignoring sheet "{}": {}'.format(sheet, str(e))) + msg = 'Ignoring sheet "{}": {}'.format(sheet, str(e)) + if logger.isEnabledFor(logging.DEBUG): + logger.exception(msg) + else: + logger.warning(msg) continue parsed_sheets.append(sheet_parts) @@ -391,7 +582,8 @@ def compile_queries(parsed_sheets, missing_value, combine_emits): # group sheets by source sheets_by_source = [] for sheet in parsed_sheets: - # Not easy to implement hashing on MiniLinq objects so can't use a dict + # Not easy to implement hashing on MiniLinq objects so can't use + # a dict for source, sheets in sheets_by_source: if sheet.source == source: sheets.append(sheet) @@ -403,7 +595,9 @@ def compile_queries(parsed_sheets, missing_value, combine_emits): for source, sheets in sheets_by_source: if len(sheets) > 1: if combine_emits: - queries.append(get_multi_emit_query(source, sheets, missing_value)) + queries.append( + get_multi_emit_query(source, sheets, missing_value) + ) else: queries.extend([ get_single_emit_query(sheet, missing_value) @@ -415,53 +609,56 @@ def compile_queries(parsed_sheets, missing_value, combine_emits): def get_multi_emit_query(source, sheets, missing_value): - """Multiple `Emit` expressions using the same data source. - For this we reverse the `Map` so that we apply each `Emit` - repeatedly for each doc produced by the data source. + """ + Multiple `Emit` expressions using the same data source. For this we + reverse the `Map` so that we apply each `Emit` repeatedly for each + doc produced by the data source. 
""" emits = [] - multi_query = Filter( # the filter here is to prevent accumulating a `[None]` value for each doc - predicate=Apply( - Reference("filter_empty"), - Reference("$") - ), - source=Map( - source=source, - body=List(emits) - ) + # the filter here is to prevent accumulating a `[None]` value for + # each doc + multi_query = Filter( + predicate=Apply(Reference("filter_empty"), Reference("$")), + source=Map(source=source, body=List(emits)) ) for sheet in sheets: - # if there is no root expression then we just reference the whole document with `this` + # if there is no root expression then we just reference the + # whole document with `this` root_expr = sheet.root_expr or Reference("`this`") emits.append( Emit( table=sheet.name, headings=sheet.headings, - source=Map( - source=root_expr, - body=sheet.body - ), + source=Map(source=root_expr, body=sheet.body), missing_value=missing_value, data_types=sheet.data_types, ) ) table_names = [sheet.name for sheet in sheets] - return Bind('checkpoint_manager', Apply(Reference('get_checkpoint_manager'), Literal(table_names)), multi_query) + data_source = sheets[ + 0].data_source # sheets will all have the same datasource + return Bind( + 'checkpoint_manager', + Apply( + Reference('get_checkpoint_manager'), + Literal(data_source), + Literal(table_names) + ), + multi_query + ) def get_single_emit_query(sheet, missing_value): - """Single `Emit` for the data source to we can just - apply the `Emit` once with the source expression being - the data source. """ + Single `Emit` for the data source to we can just apply the `Emit` + once with the source expression being the data source. + """ + def _get_source(source, root_expr): if root_expr: - return FlatMap( - source=source, - body=root_expr - ) + return FlatMap(source=source, body=root_expr) else: return source @@ -469,13 +666,20 @@ def _get_source(source, root_expr): table=sheet.name, headings=sheet.headings, source=Map( - source=_get_source(sheet.source, sheet.root_expr), - body=sheet.body + source=_get_source(sheet.source, sheet.root_expr), body=sheet.body ), missing_value=missing_value, data_types=sheet.data_types, ) - return Bind('checkpoint_manager', Apply(Reference('get_checkpoint_manager'), Literal([sheet.name])), emit) + return Bind( + 'checkpoint_manager', + Apply( + Reference('get_checkpoint_manager'), + Literal(sheet.data_source), + Literal([sheet.name]) + ), + emit + ) def check_field_length(parsed_sheets, max_column_length): @@ -499,14 +703,24 @@ def check_columns(parsed_sheets, columns): if errors_by_sheet: raise MissingColumnException(errors_by_sheet) + blacklisted_tables = [] + + def blacklist(table_name): blacklisted_tables.append(table_name) -def get_queries_from_excel(workbook, missing_value=None, combine_emits=False, - max_column_length=None, required_columns=None, - column_enforcer=None): - parsed_sheets = parse_workbook(workbook, column_enforcer) + +def get_queries_from_excel( + workbook, + missing_value=None, + combine_emits=False, + max_column_length=None, + required_columns=None, + column_enforcer=None, + value_or_root=False +): + parsed_sheets = parse_workbook(workbook, column_enforcer, value_or_root) for sheet in parsed_sheets: if sheet.name in blacklisted_tables: raise ReservedTableNameException(sheet.name) diff --git a/commcare_export/exceptions.py b/commcare_export/exceptions.py index 86fde5da..f4e7dbd2 100644 --- a/commcare_export/exceptions.py +++ b/commcare_export/exceptions.py @@ -3,6 +3,7 @@ class DataExportException(Exception): class 
LongFieldsException(DataExportException): + def __init__(self, long_fields, max_length): self.long_fields = long_fields self.max_length = max_length @@ -12,43 +13,52 @@ def message(self): message = '' for table, headers in self.long_fields.items(): message += ( - 'Table "{}" has field names longer than the maximum allowed for this database ({}):\n'.format( - table, self.max_length - )) + f'Table "{table}" has field names longer than the maximum ' + f'allowed for this database ({self.max_length}):\n' + ) for header in headers: message += ' {}\n'.format(header) - message += '\nPlease adjust field names to be within the maximum length limit of {}'.format(self.max_length) + message += ( + '\nPlease adjust field names to be within the maximum length ' + f'limit of {self.max_length}' + ) return message class MissingColumnException(DataExportException): + def __init__(self, errors_by_sheet): self.errors_by_sheet = errors_by_sheet @property def message(self): lines = [ - 'Table "{}" is missing required columns: "{}"'.format( - sheet, '", "'.join(missing_cols) - ) for sheet, missing_cols in self.errors_by_sheet.items() + 'Sheet "{}" is missing definitions for required fields: "{}"' + .format(sheet, '", "'.join(missing_cols)) + for (sheet, missing_cols) in self.errors_by_sheet.items() ] return '\n'.join(lines) class MissingQueryFileException(DataExportException): + def __init__(self, query_file): self.query_file = query_file @property def message(self): - return 'Query file not found: {}'.format(self.query_file) + return f'Query file not found: {self.query_file}' class ReservedTableNameException(DataExportException): + def __init__(self, conflicting_name): self.conflicting_name = conflicting_name @property def message(self): - return 'Table name "{}" conflicts with an internal table name. Please export to a different table.'.format(self.conflicting_name) + return ( + f'Table name "{self.conflicting_name}" conflicts with an internal ' + f'table name. Please export to a different table.' 
+ ) diff --git a/commcare_export/jsonpath_utils.py b/commcare_export/jsonpath_utils.py new file mode 100644 index 00000000..e694956e --- /dev/null +++ b/commcare_export/jsonpath_utils.py @@ -0,0 +1,14 @@ +from jsonpath_ng import jsonpath + + +def split_leftmost(jsonpath_expr): + if isinstance(jsonpath_expr, jsonpath.Child): + further_leftmost, rest = split_leftmost(jsonpath_expr.left) + return further_leftmost, rest.child(jsonpath_expr.right) + elif isinstance(jsonpath_expr, jsonpath.Descendants): + further_leftmost, rest = split_leftmost(jsonpath_expr.left) + return further_leftmost, jsonpath.Descendants( + rest, jsonpath_expr.right + ) + else: + return jsonpath_expr, jsonpath.This() diff --git a/commcare_export/location_info_provider.py b/commcare_export/location_info_provider.py index 8911a523..e71f2b71 100644 --- a/commcare_export/location_info_provider.py +++ b/commcare_export/location_info_provider.py @@ -1,16 +1,18 @@ -import logging -from commcare_export.misc import unwrap_val from commcare_export.commcare_minilinq import SimplePaginator +from commcare_export.misc import unwrap_val +from commcare_export import get_logger -logger = logging.getLogger(__name__) +logger = get_logger(__file__) -# LocationInfoProvider uses the /location_type/ endpoint of the API -# to retrieve location type data, stores that information in a dictionary +# LocationInfoProvider uses the /location_type/ endpoint of the API to +# retrieve location type data, stores that information in a dictionary # keyed by resource URI and provides the method 'get_location_info' to # extract values from the dictionary. + class LocationInfoProvider: + def __init__(self, api_client, page_size): self._api_client = api_client self._page_size = page_size @@ -43,8 +45,9 @@ def get_location_types(self): paginator = SimplePaginator('location_type', self._page_size) paginator.init(None, False, None) location_type_dict = {} - for row in self._api_client.iterate('location_type', paginator, - {'limit': self._page_size}): + for row in self._api_client.iterate( + 'location_type', paginator, {'limit': self._page_size} + ): location_type_dict[row['resource_uri']] = row return location_type_dict @@ -60,8 +63,9 @@ def get_location_hierarchy(self): # Extract every location, its type and its parent location_data = {} - for row in self._api_client.iterate('location', paginator, - {'limit': self._page_size}): + for row in self._api_client.iterate( + 'location', paginator, {'limit': self._page_size} + ): location_data[row['resource_uri']] = { 'location_id': row['location_id'], 'location_type': row['location_type'], @@ -70,19 +74,24 @@ def get_location_hierarchy(self): # Build a map from location resource_uri to a map from # location_type_code to ancestor location id. 
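To make the comment above concrete, here is an invented example of the mapping that get_location_hierarchy() builds (resource URIs, type codes and ids are placeholders, not real data):

    # Each location URI maps to its own id plus the ids of its
    # ancestors, keyed by location type code.
    ancestors = {
        '/a/demo/api/v0.5/location/uuid-village-1/': {
            'village': 'loc-village-1',
            'district': 'loc-district-9',
            'state': 'loc-state-2',
        },
        '/a/demo/api/v0.5/location/uuid-district-9/': {
            'district': 'loc-district-9',
            'state': 'loc-state-2',
        },
    }
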
- ancestors = {} # includes location itself + ancestors = {} # includes location itself for resource_uri in location_data: loc_uri = resource_uri type_code_to_id = {} while loc_uri is not None: if loc_uri not in location_data: - logger.warning('Unknown location referenced: {}'.format(loc_uri)) + logger.warning( + 'Unknown location referenced: {}'.format(loc_uri) + ) break loc_data = location_data[loc_uri] loc_type = loc_data['location_type'] if loc_type not in self.location_types: - logger.warning('Unknown location type referenced: {}'.format(loc_type)) + logger.warning( + 'Unknown location type referenced: {}' + .format(loc_type) + ) break type_code = self.location_types[loc_type]['code'] @@ -90,5 +99,3 @@ def get_location_hierarchy(self): loc_uri = loc_data['parent'] ancestors[resource_uri] = type_code_to_id return ancestors - - diff --git a/commcare_export/map_format.py b/commcare_export/map_format.py index 56a7c6e4..8f413426 100644 --- a/commcare_export/map_format.py +++ b/commcare_export/map_format.py @@ -1,6 +1,6 @@ import re -from commcare_export.minilinq import Literal, Apply, Reference +from commcare_export.minilinq import Apply, Literal, Reference SELECTED_AT = 'selected-at' SELECTED = 'selected' @@ -9,6 +9,7 @@ class ParsingException(Exception): + def __init__(self, message): self.message = message @@ -25,7 +26,9 @@ def parse_function_arg(slug, expr_string): matches = re.match(regex, expr_string) if not matches: - raise ParsingException('Error: Unable to parse: {}'.format(expr_string)) + raise ParsingException( + 'Error: Unable to parse: {}'.format(expr_string) + ) return matches.groups()[0] @@ -35,7 +38,10 @@ def parse_selected_at(value_expr, selected_at_expr_string): try: index = int(index) except ValueError: - return Literal('Error: selected-at index must be an integer: {}'.format(selected_at_expr_string)) + return Literal( + 'Error: selected-at index must be an integer: {}' + .format(selected_at_expr_string) + ) return Apply(Reference(SELECTED_AT), value_expr, Literal(index)) @@ -49,7 +55,10 @@ def parse_template(value_expr, format_expr_string): args_string = parse_function_arg(TEMPLATE, format_expr_string) args = [arg.strip() for arg in args_string.split(',') if arg.strip()] if len(args) < 1: - return Literal('Error: template function requires the format template: {}'.format(format_expr_string)) + return Literal( + 'Error: template function requires the format template: ' + f'{format_expr_string}' + ) template = args.pop(0) if args: args = [Reference(arg) for arg in args] @@ -63,7 +72,10 @@ def parse_substr(value_expr, substr_expr_string): regex = r'^\s*(\d+)\s*,\s*(\d+)\s*$' matches = re.match(regex, args_string) if not matches or len(matches.groups()) != 2: - raise ParsingException('Error: both substr arguments must be non-negative integers: {}'.format(substr_expr_string)) + raise ParsingException( + 'Error: both substr arguments must be non-negative integers: ' + f'{substr_expr_string}' + ) # These conversions should always succeed after a pattern match. 
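As a quick, illustrative check of the substr pattern above (not part of the patch): parse_function_arg('substr', 'substr(2, 5)') yields the argument string between the parentheses, which the regex then splits into the two integers.

    import re

    matches = re.match(r'^\s*(\d+)\s*,\s*(\d+)\s*$', '2, 5')
    assert matches is not None
    assert matches.groups() == ('2', '5')
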
start = int(matches.groups()[0]) diff --git a/commcare_export/minilinq.py b/commcare_export/minilinq.py index 66aa79f2..b0d476f3 100644 --- a/commcare_export/minilinq.py +++ b/commcare_export/minilinq.py @@ -1,16 +1,15 @@ -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes -import logging - -import six -from six.moves import map +from typing import Any, Dict +from typing import List as ListType +from typing import Optional +from commcare_export.env import Env from commcare_export.misc import unwrap, unwrap_val - from commcare_export.repeatable_iterator import RepeatableIterator - from commcare_export.specs import TableSpec +from commcare_export import get_logger + +logger = get_logger(__file__) -logger = logging.getLogger(__name__) class MiniLinq(object): """ @@ -18,16 +17,12 @@ class MiniLinq(object): for dispatching parsing, etc. """ - def __init__(self, *args, **kwargs): + def eval(self, env: Env) -> Any: raise NotImplementedError() - def eval(self, env): - "( env: object(bindings: {str: ??}, writer: Writer) )-> ??" - raise NotImplementedError() - #### Factory methods #### - _node_classes = {} + _node_classes: Dict[str, 'MiniLinq'] = {} @classmethod def register(cls, clazz, slug=None): @@ -36,21 +31,23 @@ def register(cls, clazz, slug=None): @classmethod def from_jvalue(cls, jvalue): """ - The term `jvalue` is code for "the output of a JSON deserialization". This - module does not actually care about JSON, which is concrete syntax, but - only the corresponding data model of lists and string-indexed dictionaries. + The term `jvalue` is code for "the output of a JSON + deserialization". This module does not actually care about + JSON, which is concrete syntax, but only the corresponding data + model of lists and string-indexed dictionaries. - (since this data might never actually be a string, that layer is handled elsewhere) + (since this data might never actually be a string, that layer is + handled elsewhere) """ - # This is a bit wonky, but this method really should not be inherited. - # So if we end up here from a subclass, it is broken. + # This is a bit wonky, but this method really should not be + # inherited. So if we end up here from a subclass, it is broken. if not issubclass(MiniLinq, cls): raise NotImplementedError() - - if isinstance(jvalue, six.string_types): + + if isinstance(jvalue, str): return jvalue - + elif isinstance(jvalue, list): # Leverage for literal lists of data in the code return [MiniLinq.from_jvalue(v) for v in jvalue] @@ -59,25 +56,34 @@ def from_jvalue(cls, jvalue): # Dictionaries are reserved; they must always have exactly # one entry and it must be the AST node class if len(jvalue.values()) != 1: - raise ValueError('JValue serialization of AST contains dict with number of slugs != 1') + raise ValueError( + 'JValue serialization of AST contains dict with number of slugs != 1' + ) slug = list(jvalue.keys())[0] if slug not in cls._node_classes: - raise ValueError('JValue serialization of AST contains unknown node type: %s' % slug) + raise ValueError( + 'JValue serialization of AST contains unknown node type: %s' + % slug + ) return cls._node_classes[slug].from_jvalue(jvalue) + def to_jvalue(self): + raise NotImplementedError() + class Reference(MiniLinq): """ - An MiniLinq referencing a datum or data. It is flexible - about what the type of the environment is, but it must - support using these as keys. + An MiniLinq referencing a datum or data. 
It is flexible about what + the type of the environment is, but it must support using these as + keys. """ + def __init__(self, ref): - self.ref = ref #parse_jsonpath(ref) #ref + self.ref = ref #parse_jsonpath(ref) #ref self.nested = isinstance(self.ref, MiniLinq) - + def eval(self, env): if self.nested: ref = self.ref.eval(env) @@ -100,14 +106,15 @@ def __repr__(self): class Literal(MiniLinq): """ - An MiniLinq wrapper around a python value. Returns exactly the - value given to it. Note: when going to/from jvalue the - contents are left alone, so it can be _used_ with a non-JSON - encodable value, but cannot be encoded. + An MiniLinq wrapper around a python value. Returns exactly the value + given to it. Note: when going to/from jvalue the contents are left + alone, so it can be _used_ with a non-JSON encodable value, but + cannot be encoded. """ + def __init__(self, v): self.v = v - + def eval(self, env): return self.v @@ -127,15 +134,13 @@ def to_jvalue(self): class Bind(MiniLinq): """ - Binds the results of an expression to a new name. Will be useful - in writing exports by hand or debugging, and maybe for efficiency - if it de-dupes computation (but generally exports will be - expected to be too large to store, so it'll be re-run on each - access. + Binds the results of an expression to a new name. Will be useful in + writing exports by hand or debugging, and maybe for efficiency if it + de-dupes computation (but generally exports will be expected to be + too large to store, so it'll be re-run on each access. """ - def __init__(self, name, value, body): - "(str, MiniLinq, MiniLinq) -> MiniLinq" + def __init__(self, name: str, value: MiniLinq, body: MiniLinq) -> None: self.name = name self.value = value self.body = body @@ -144,25 +149,32 @@ def eval(self, env): return self.body.eval(env.bind(self.name, self.value.eval(env))) def __eq__(self, other): - return isinstance(other, Bind) and self.name == other.name and self.value == other.value and self.body == other.body + return isinstance( + other, Bind + ) and self.name == other.name and self.value == other.value and self.body == other.body def __repr__(self): - return '%s(name=%r, value=%r, body=%r)' % (self.__class__.__name__, self.name, self.value, self.body) + return '%s(name=%r, value=%r, body=%r)' % ( + self.__class__.__name__, self.name, self.value, self.body + ) @classmethod def from_jvalue(cls, jvalue): fields = jvalue['Bind'] - return cls(name=fields['name'], - value=MiniLinq.from_jvalue(fields['value']), - body=MiniLinq.from_jvalue(fields['body'])) + return cls( + name=fields['name'], + value=MiniLinq.from_jvalue(fields['value']), + body=MiniLinq.from_jvalue(fields['body']) + ) def to_jvalue(self): - return {'Bind':{'name': self.name, - 'value': self.value.to_jvalue(), - 'body': self.body.to_jvalue()}} - - def __repr__(self): - return '%s(name=%r, value=%r, body=%r)' % (self.__class__.__name__, self.name, self.value, self.body) + return { + 'Bind': { + 'name': self.name, + 'value': self.value.to_jvalue(), + 'body': self.body.to_jvalue() + } + } class Filter(MiniLinq): @@ -170,8 +182,12 @@ class Filter(MiniLinq): Just what it sounds like """ - def __init__(self, source, predicate, name=None): - "(MiniLinq, MiniLinq, var?) 
-> MiniLinq" + def __init__( + self, + source: MiniLinq, + predicate: MiniLinq, + name: Optional[str] = None + ) -> None: self.source = source self.name = name self.predicate = predicate @@ -179,47 +195,63 @@ def __init__(self, source, predicate, name=None): def eval(self, env): source_result = self.source.eval(env) - def iterate(env=env, source_result=source_result): # Python closure workaround + # Python closure workaround + def iterate( + env_=env, + source_result_=source_result, + ): if self.name: - for item in source_result: - if self.predicate.eval(env.bind(self.name, item)): + for item in source_result_: + if self.predicate.eval(env_.bind(self.name, item)): yield item else: - for item in source_result: - if self.predicate.eval(env.replace(item)): + for item in source_result_: + if self.predicate.eval(env_.replace(item)): yield item return RepeatableIterator(iterate) def __eq__(self, other): - return isinstance(other, Filter) and self.source == other.source and self.name == other.name and self.predicate == other.predicate + return ( + isinstance(other, Filter) and self.source == other.source + and self.name == other.name and self.predicate == other.predicate + ) @classmethod def from_jvalue(cls, jvalue): fields = jvalue['Filter'] # TODO: catch errors and give informative error messages - return cls(predicate = MiniLinq.from_jvalue(fields['predicate']), - source = MiniLinq.from_jvalue(fields['source']), - name = fields.get('name')) + return cls( + predicate=MiniLinq.from_jvalue(fields['predicate']), + source=MiniLinq.from_jvalue(fields['source']), + name=fields.get('name') + ) def to_jvalue(self): - return {'Filter': {'predicate': self.predicate.to_jvalue(), - 'source': self.source.to_jvalue(), - 'name': self.name}} + return { + 'Filter': { + 'predicate': self.predicate.to_jvalue(), + 'source': self.source.to_jvalue(), + 'name': self.name + } + } def __repr__(self): - return '%s(source=%r, name=%r, predicate=%r)' % (self.__class__.__name__, self.source, self.name, self.predicate) + return '%s(source=%r, name=%r, predicate=%r)' % ( + self.__class__.__name__, self.source, self.name, self.predicate + ) class List(MiniLinq): """ - A list of expressions, embeds the [ ... ] syntax into the - MiniLinq meta-leval + A list of expressions, embeds the [ ... ] syntax into the MiniLinq + meta-leval """ + def __init__(self, items): self.items = items - + def eval(self, env): return [item.eval(env) for item in self.items] @@ -239,26 +271,30 @@ def to_jvalue(self): class Map(MiniLinq): """ - Like the `FROM` clause of a SQL `SELECT` or jQuery's map, - binds each item from its `source` and evaluates - the body MiniLinq. - - If `name` is provided to the constructor, then instead of - replacing the environment with each row, it will just - bind the row to `name`, enabling references to the - rest of the env. + Like the `FROM` clause of a SQL `SELECT` or jQuery's map, binds each + item from its `source` and evaluates the body MiniLinq. + + If `name` is provided to the constructor, then instead of replacing + the environment with each row, it will just bind the row to `name`, + enabling references to the rest of the env. """ - def __init__(self, source, body, name=None): - "(MiniLinq, MiniLinq, var?) 
-> MiniLinq" + def __init__( + self, + source: MiniLinq, + body: MiniLinq, + name: Optional[str] = None + ) -> None: self.source = source self.name = name self.body = body - + def eval(self, env): source_result = self.source.eval(env) - def iterate(env=env, source_result=source_result): # Python closure workaround + def iterate( + env=env, source_result=source_result + ): # Python closure workaround if self.name: for item in source_result: yield self.body.eval(env.bind(self.name, item)) @@ -269,81 +305,106 @@ def iterate(env=env, source_result=source_result): # Python closure workaround return RepeatableIterator(iterate) def __eq__(self, other): - return isinstance(other, Map) and self.name == other.name and self.source == other.source and self.body == other.body + return ( + isinstance(other, Map) and self.name == other.name + and self.source == other.source and self.body == other.body + ) @classmethod def from_jvalue(cls, jvalue): fields = jvalue['Map'] # TODO: catch errors and give informative error messages - return cls(body = MiniLinq.from_jvalue(fields['body']), - source = MiniLinq.from_jvalue(fields['source']), - name = fields.get('name')) + return cls( + body=MiniLinq.from_jvalue(fields['body']), + source=MiniLinq.from_jvalue(fields['source']), + name=fields.get('name') + ) def to_jvalue(self): - return {'Map': {'body': self.body.to_jvalue(), - 'source': self.source.to_jvalue(), - 'name': self.name}} + return { + 'Map': { + 'body': self.body.to_jvalue(), + 'source': self.source.to_jvalue(), + 'name': self.name + } + } class FlatMap(MiniLinq): """ - Somewhat like a JOIN, but not quite. Called `SelectMany` - in LINQ and `flatMap` other languages. Obvious equivalence: - `flatMap f = flatten . map f` but so common it is useful to - have around. - - If `name` is provided to the constructor, then instead of - replacing the environment with each row, it will just - bind the row to `name`, enabling references to the - rest of the env. + Somewhat like a JOIN, but not quite. Called `SelectMany` in LINQ and + `flatMap` other languages. Obvious equivalence: `flatMap f = flatten + . map f` but so common it is useful to have around. + + If `name` is provided to the constructor, then instead of replacing + the environment with each row, it will just bind the row to `name`, + enabling references to the rest of the env. """ - def __init__(self, source, body, name=None): - "(MiniLinq, MiniLinq, var?) 
-> MiniLinq" + def __init__( + self, + source: MiniLinq, + body: MiniLinq, + name: Optional[str] = None + ) -> None: self.source = source self.name = name self.body = body - + def eval(self, env): source_result = self.source.eval(env) - def iterate(env=env, source_result=source_result): # Python closure workaround + # Python closure workaround + def iterate( + env_=env, + source_result_=source_result, + ): if self.name: - for item in source_result: - for result_item in self.body.eval(env.bind(self.name, item)): + for item in source_result_: + for result_item in self.body.eval( + env_.bind(self.name, item) + ): yield result_item else: - for item in source_result: - for result_item in self.body.eval(env.replace(item)): + for item in source_result_: + for result_item in self.body.eval(env_.replace(item)): yield result_item return RepeatableIterator(iterate) def __eq__(self, other): - return isinstance(other, FlatMap) and self.name == other.name and self.source == other.source and self.body == other.body - + return ( + isinstance(other, FlatMap) and self.name == other.name + and self.source == other.source and self.body == other.body + ) @classmethod def from_jvalue(cls, jvalue): fields = jvalue['FlatMap'] # TODO: catch errors and give informative error messages - return cls(body = MiniLinq.from_jvalue(fields['body']), - source = MiniLinq.from_jvalue(fields['source']), - name = fields.get('name')) + return cls( + body=MiniLinq.from_jvalue(fields['body']), + source=MiniLinq.from_jvalue(fields['source']), + name=fields.get('name') + ) def to_jvalue(self): - return {'FlatMap': {'body': self.body.to_jvalue(), - 'source': self.source.to_jvalue(), - 'name': self.name}} + return { + 'FlatMap': { + 'body': self.body.to_jvalue(), + 'source': self.source.to_jvalue(), + 'name': self.name + } + } class Apply(MiniLinq): """ Abstract syntax for function or operator application. """ - + def __init__(self, fn, *args): self.fn = fn self.args = args @@ -364,46 +425,71 @@ def eval(self, env): doc_id = 'unknown' message = e.args[0] + ( - ": Error processing document '%s'. " - "Failure to evaluating expression '%r' with arguments '%s'" - ) % (doc_id, self, args) + f": Error processing document '{doc_id}'. Failure to " + f"evaluating expression '{self!r}' with arguments '{args}'" + ) e.args = (message,) + e.args[1:] raise return result def __eq__(self, other): - return isinstance(other, Apply) and self.fn == other.fn and self.args == other.args + return ( + isinstance(other, Apply) and self.fn == other.fn + and self.args == other.args + ) @classmethod def from_jvalue(cls, jvalue): fields = jvalue['Apply'] # TODO: catch errors and give informative error messages - return cls(MiniLinq.from_jvalue(fields['fn']), - *[MiniLinq.from_jvalue(arg) for arg in fields['args']]) + return cls( + MiniLinq.from_jvalue(fields['fn']), + *[MiniLinq.from_jvalue(arg) for arg in fields['args']] + ) def to_jvalue(self): - return {'Apply': {'fn': self.fn.to_jvalue(), - 'args': [arg.to_jvalue() for arg in self.args]}} + return { + 'Apply': { + 'fn': self.fn.to_jvalue(), + 'args': [arg.to_jvalue() for arg in self.args] + } + } def __repr__(self): - return '%s(%r, *%r)' % (self.__class__.__name__, self.fn, self.args) + return f'{self.__class__.__name__}({self.fn!r}, *{self.args!r})' class Emit(MiniLinq): """ - This MiniLinq writes a whole table to whatever writer is registered in the `env`. - In practice, a table to a dict of a name, headers, and rows, so the - writer is free to do an idempotent upsert, etc. 
+ This MiniLinq writes a whole table to whatever writer is registered + in the `env`. In practice, a table to a dict of a name, headers, + and rows, so the writer is free to do an idempotent upsert, etc. Note that it does not actually check that the number of headings is - correct, nor does it try to ensure that the things being emitted - are actually lists - it is just crashy instead. + correct, nor does it try to ensure that the things being emitted are + actually lists - it is just crashy instead. """ - def __init__(self, table, headings, source, missing_value=None, data_types=None): - "(str, [str], [MiniLinq]) -> MiniLinq" + def __init__( + self, + table: str, + headings: ListType[MiniLinq], + source: MiniLinq, + missing_value: Optional[str] = None, + data_types: Optional[ListType[Literal]] = None, + ) -> None: + """ + Initializes an ``Emit`` instance. + + :param table: The name/title of the table to be written. + :param headings: Evaluated to determine column headings. + :param source: Evaluated to determine the table rows. + :param missing_value: Denotes "no value". e.g. ``"---"`` + :param data_types: The data types of the columns. e.g. + ``[Literal('text'), Literal('date'), ...]`` + """ self.table = table self.headings = headings self.source = source @@ -423,20 +509,24 @@ def coerce_cell(self, cell): try: return self.coerce_cell_blithely(cell) except Exception: - logger.exception('Error converting value to exportable form: %r' % cell) + logger.exception( + 'Error converting value to exportable form: %r' % cell + ) return '' - + def coerce_row(self, row): return [self.coerce_cell(cell) for cell in row] def eval(self, env): rows = self.source.eval(env) - env.emit_table(TableSpec( - name=self.table, - headings=[heading.eval(env) for heading in self.headings], - rows=list(map(self.coerce_row, rows)), - data_types=[lit.v for lit in self.data_types] - )) + env.emit_table( + TableSpec( + name=self.table, + headings=[heading.eval(env) for heading in self.headings], + rows=map(self.coerce_row, rows), + data_types=[lit.v for lit in self.data_types] + ) + ) @classmethod def from_jvalue(cls, jvalue): @@ -444,30 +534,44 @@ def from_jvalue(cls, jvalue): return cls( table=fields['table'], source=MiniLinq.from_jvalue(fields['source']), - headings=[MiniLinq.from_jvalue(heading) for heading in fields['headings']], + headings=[ + MiniLinq.from_jvalue(heading) for heading in fields['headings'] + ], missing_value=fields.get('missing_value'), data_types=fields.get('data_types'), ) def to_jvalue(self): - return {'Emit': {'table': self.table, - 'headings': [heading.to_jvalue() for heading in self.headings], - 'source': self.source.to_jvalue(), - 'missing_value': self.missing_value, - 'data_types': [heading.to_jvalue() for heading in self.headings]}} + return { + 'Emit': { + 'table': + self.table, + 'headings': [heading.to_jvalue() for heading in self.headings], + 'source': + self.source.to_jvalue(), + 'missing_value': + self.missing_value, + 'data_types': [ + heading.to_jvalue() for heading in self.headings + ] + } + } def __eq__(self, other): return ( isinstance(other, Emit) and self.table == other.table - and self.headings == other.headings - and self.source == other.source + and self.headings == other.headings and self.source == other.source and self.missing_value == other.missing_value and self.data_types == other.data_types ) def __repr__(self): return '%s(table=%r, headings=%r, source=%r, missing_value=%r)' % ( - self.__class__.__name__, self.table, self.headings, self.source, self.missing_value + 
self.__class__.__name__, + self.table, + self.headings, + self.source, + self.missing_value ) diff --git a/commcare_export/misc.py b/commcare_export/misc.py index 6858e20b..f399e605 100644 --- a/commcare_export/misc.py +++ b/commcare_export/misc.py @@ -1,17 +1,18 @@ -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes import functools import hashlib import inspect import io -from jsonpath_rw import jsonpath + from commcare_export.repeatable_iterator import RepeatableIterator +from jsonpath_ng import jsonpath def digest_file(path): with io.open(path, 'rb') as filehandle: digest = hashlib.md5() while True: - chunk = filehandle.read(4096) # Arbitrary choice of size to be ~filesystem block size friendly + # Arbitrary choice of size to be ~filesystem block size friendly + chunk = filehandle.read(4096) if not chunk: break digest.update(chunk) @@ -21,6 +22,7 @@ def digest_file(path): def unwrap(arg_name): def unwrapper(fn): + @functools.wraps(fn) def _inner(*args): callargs = inspect.getcallargs(fn, *args) diff --git a/commcare_export/repeatable_iterator.py b/commcare_export/repeatable_iterator.py index 17285e44..75b22d3e 100644 --- a/commcare_export/repeatable_iterator.py +++ b/commcare_export/repeatable_iterator.py @@ -1,12 +1,9 @@ -from types import GeneratorType - - class RepeatableIterator(object): """ - Pass something iterable into this and, - unless it has crufty issues, voila. + Pass something iterable into this and, unless it has crufty issues, + voila. """ - + def __init__(self, generator): self.generator = generator self.__val = None diff --git a/commcare_export/specs.py b/commcare_export/specs.py index d5b1e0a6..f4ee0216 100644 --- a/commcare_export/specs.py +++ b/commcare_export/specs.py @@ -1,5 +1,3 @@ - - class TableSpec: def __init__(self, name, headings, rows, data_types=None): @@ -13,7 +11,6 @@ def __eq__(self, other): isinstance(other, TableSpec) and other.name == self.name and other.headings == self.headings - and other.rows == self.rows and other.data_types == self.data_types ) @@ -21,6 +18,5 @@ def toJSON(self): return { 'name': self.name, 'headings': self.headings, - 'rows': self.rows, 'data_types': self.data_types, } diff --git a/commcare_export/utils.py b/commcare_export/utils.py index 74cfceda..7a7d7e4b 100644 --- a/commcare_export/utils.py +++ b/commcare_export/utils.py @@ -1,11 +1,8 @@ -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes - import sys from commcare_export import misc from commcare_export.checkpoint import CheckpointManager -from six.moves import input - +from commcare_export.specs import TableSpec from commcare_export.writers import StreamingMarkdownTableWriter @@ -18,17 +15,19 @@ def get_checkpoint_manager(args, require_query=True): raise return CheckpointManager( - args.output, args.query, md5, - args.project, args.commcare_hq, args.checkpoint_key + args.output, + args.query, + md5, + args.project, + args.commcare_hq, + args.checkpoint_key ) def confirm(message): - confirm = input( - """ - {}? [y/N] - """.format(message) - ) + confirm = input(f""" + {message}? 
[y/N] + """) return confirm == "y" @@ -37,19 +36,33 @@ def print_runs(runs): rows = [] for run in runs: rows.append([ - run.time_of_run, run.since_param, "True" if run.final else "False", - run.project, run.query_file_name, run.query_file_md5, run.key, run.table_name, run.commcare + run.time_of_run, + run.since_param, + "True" if run.final else "False", + run.project, + run.query_file_name, + run.query_file_md5, + run.key, + run.table_name, + run.commcare ]) - rows = [ - [val if val is not None else '' for val in row] - for row in rows - ] + rows = [[val if val is not None else '' for val in row] for row in rows] - StreamingMarkdownTableWriter(sys.stdout, compute_widths=True).write_table({ - 'headings': [ - "Checkpoint Time", "Batch end date", "Export Complete", - "Project", "Query Filename", "Query MD5", "Key", "Table", "CommCare HQ" + StreamingMarkdownTableWriter( + sys.stdout, compute_widths=True + ).write_table(TableSpec( + name='', + headings=[ + "Checkpoint Time", + "Batch end date", + "Export Complete", + "Project", + "Query Filename", + "Query MD5", + "Key", + "Table", + "CommCare HQ" ], - 'rows': rows - }) + rows=rows, + )) diff --git a/commcare_export/utils_cli.py b/commcare_export/utils_cli.py index b7e19e0d..2543ee63 100644 --- a/commcare_export/utils_cli.py +++ b/commcare_export/utils_cli.py @@ -1,17 +1,13 @@ -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes - import argparse import inspect import logging import sys from commcare_export.cli import CLI_ARGS -from commcare_export.utils import get_checkpoint_manager, confirm, print_runs +from commcare_export.utils import confirm, get_checkpoint_manager, print_runs EXIT_STATUS_ERROR = 1 -logger = logging.getLogger(__name__) - class BaseCommand(object): slug = None @@ -28,14 +24,18 @@ def run(self, args): class ListHistoryCommand(BaseCommand): slug = 'history' help = """List export history. History will be filtered by arguments provided. - - This command only applies when exporting to a SQL database. The command lists - the checkpoints that have been created by the command. + + This command only applies when exporting to a SQL database. The command + lists the checkpoints that have been created by the command. """ @classmethod def add_arguments(cls, parser): - parser.add_argument('--limit', default=10, help="Limit the number of export runs to display") + parser.add_argument( + '--limit', + default=10, + help="Limit the number of export runs to display" + ) parser.add_argument('--output', required=True, help='SQL Database URL') shared_args = {'project', 'query', 'checkpoint_key', 'commcare_hq'} for arg in CLI_ARGS: @@ -64,22 +64,27 @@ class SetKeyCommand(BaseCommand): slug = 'set-checkpoint-key' help = """Set the key for a particular checkpoint. - This command is used to migrate an non-keyed checkpoint to a keyed checkpoint. + This command is used to migrate an non-keyed checkpoint to a keyed + checkpoint. - This is useful if you already have a populated export database and do not wish to trigger - rebuilds after editing the query file. + This is useful if you already have a populated export database and do + not wish to trigger rebuilds after editing the query file. - For example, you've been running the export tool with query file A.xlsx and have a fully populated - database. Now you need to add an extra column to the table but only want to populate it with new data. 
+ For example, you've been running the export tool with query file A.xlsx + and have a fully populated database. Now you need to add an extra column + to the table but only want to populate it with new data. - What you need to do is update your current checkpoint with a key that you can then use when running - the command from now on. + What you need to do is update your current checkpoint with a key that + you can then use when running the command from now on. - $ commcare-export-utils set-key --project X --query A.xlsx --output [SQL URL] --checkpoint-key my-key + $ commcare-export-utils set-key --project X --query A.xlsx \\ + --output [SQL URL] --checkpoint-key my-key Now when you run the export tool in future you can use this key: - $ commcare-export --project X --query A.xlsx --output [SQL URL] --checkpoint-key my-key ... + $ commcare-export --project X --query A.xlsx --output [SQL URL] \\ + --checkpoint-key my-key ... + """ @classmethod @@ -111,7 +116,9 @@ def run(self, args): return print_runs(runs_no_key) - if confirm("Do you want to set the key for this checkpoint to '{}'".format(key)): + if confirm( + f"Do you want to set the key for this checkpoint to '{key}'" + ): for checkpoint in runs_no_key: checkpoint.key = key manager.update_checkpoint(checkpoint) @@ -120,10 +127,7 @@ def run(self, args): print_runs(runs_no_key) -COMMANDS = [ - ListHistoryCommand, - SetKeyCommand -] +COMMANDS = [ListHistoryCommand, SetKeyCommand] def main(argv): @@ -138,21 +142,14 @@ def main(argv): ) command_type.add_arguments(sub) - try: - args = parser.parse_args(argv) - except UnicodeDecodeError: - for arg in argv: - try: - arg.encode('utf-8') - except UnicodeDecodeError: - sys.stderr.write(u"ERROR: Argument '%s' contains unicode characters. " - u"Only ASCII characters are supported.\n" % unicode(arg, 'utf-8')) - sys.exit(1) - - logging.basicConfig(level=logging.WARN, - format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s') - - exit(main_with_args(args)) + args = parser.parse_args(argv) + + logging.basicConfig( + level=logging.WARN, + format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s' + ) + + sys.exit(main_with_args(args)) def main_with_args(args): diff --git a/commcare_export/version.py b/commcare_export/version.py index d4e1d4e1..5f3c3362 100644 --- a/commcare_export/version.py +++ b/commcare_export/version.py @@ -1,7 +1,7 @@ -from __future__ import print_function, unicode_literals import io -import subprocess +import re import os.path +import subprocess __all__ = ['__version__', 'stored_version', 'git_version'] @@ -17,8 +17,29 @@ def stored_version(): def git_version(): - described_version_bytes = subprocess.Popen(['git', 'describe'], stdout=subprocess.PIPE).communicate()[0].strip() - return described_version_bytes.decode('ascii') + if os.environ.get('DET_EXECUTABLE'): + return None + + described_version_bytes = subprocess.Popen( + ['git', 'describe'], + stdout=subprocess.PIPE + ).communicate()[0].strip() + version_raw = described_version_bytes.decode('ascii') + return parse_version(version_raw) + + +def parse_version(version_raw): + """Attempt to convert a git version to a version + compatible with PEP440: https://peps.python.org/pep-0440/ + """ + match = re.match('(\d+\.\d+\.\d+)(?:-(\d+).*)?', version_raw) + if match: + tag_version, lead_count = match.groups() + if lead_count: + tag_version += ".dev{}".format(lead_count) + return tag_version + + return version_raw def version(): diff --git a/commcare_export/writers.py b/commcare_export/writers.py index fd104755..bfeef62b 100644 
--- a/commcare_export/writers.py +++ b/commcare_export/writers.py @@ -1,20 +1,19 @@ +import csv import datetime import io -import logging import zipfile -from six.moves import zip_longest +from itertools import zip_longest -import alembic -import csv342 as csv -import six import sqlalchemy -from six import u +from sqlalchemy.exc import NoSuchTableError +from alembic.migration import MigrationContext +from alembic.operations import Operations from commcare_export.data_types import UnknownDataType, get_sqlalchemy_type from commcare_export.specs import TableSpec +from commcare_export import get_logger -logger = logging.getLogger(__name__) - +logger = get_logger(__file__) MAX_COLUMN_SIZE = 2000 @@ -22,49 +21,51 @@ def ensure_text(v, convert_none=False): if v is None: return '' if convert_none else v - if isinstance(v, six.text_type): + if isinstance(v, str): + return v + elif isinstance(v, bytes): return v - elif isinstance(v, six.binary_type): - return u(v) elif isinstance(v, datetime.datetime): return v.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(v, datetime.date): return v.isoformat() else: - return u(str(v)) + return str(v) + def to_jvalue(v): if v is None: return None - if isinstance(v, (six.text_type,) + six.integer_types): + if isinstance(v, (str, int)): + return v + elif isinstance(v, bytes): return v - elif isinstance(v, six.binary_type): - return u(v) else: - return u(str(v)) + return str(v) + class TableWriter(object): """ - Interface for export writers: Usable in a "with" - statement, and while open one can call write_table. + Interface for export writers: Usable in a "with" statement, and + while open one can call write_table. - If the implementing class does not actually need any - set up, no-op defaults have been provided + If the implementing class does not actually need any set up, no-op + defaults have been provided. 
""" max_column_length = None support_checkpoints = False - # set to False if writer does not support writing to the same table multiple times + # set to False if writer does not support writing to the same table + # multiple times supports_multi_table_write = True required_columns = None def __enter__(self): return self - - def write_table(self, table): - "{'name': str, 'headings': [str], 'rows': [[str]]} -> ()" + + def write_table(self, table: TableSpec) -> None: raise NotImplementedError() def __exit__(self, exc_type, exc_val, exc_tb): @@ -78,7 +79,7 @@ def __init__(self, file, max_column_size=MAX_COLUMN_SIZE): self.file = file self.tables = [] self.archive = None - + def __enter__(self): self.archive = zipfile.ZipFile(self.file, 'w', zipfile.ZIP_DEFLATED) return self @@ -87,22 +88,18 @@ def write_table(self, table): if self.archive is None: raise Exception('Attempt to write to a closed CsvWriter') - def _encode_row(row): - return [ - val.encode('utf-8') if isinstance(val, bytes) else val - for val in row - ] - tempfile = io.StringIO() writer = csv.writer(tempfile, dialect=csv.excel) - writer.writerow(_encode_row(table.headings)) + writer.writerow(table.headings) for row in table.rows: - writer.writerow(_encode_row(row)) + writer.writerow(row) - # TODO: make this a polite zip and put everything in a subfolder with the same basename - # as the zipfile - self.archive.writestr('%s.csv' % self.zip_safe_name(table.name), - tempfile.getvalue().encode('utf-8')) + # TODO: make this a polite zip and put everything in a subfolder + # with the same basename as the zipfile + self.archive.writestr( + '%s.csv' % self.zip_safe_name(table.name), + tempfile.getvalue().encode('utf-8') + ) def __exit__(self, exc_type, exc_val, exc_tb): self.archive.close() @@ -113,14 +110,16 @@ def zip_safe_name(self, name): class Excel2007TableWriter(TableWriter): max_table_name_size = 31 - + def __init__(self, file): try: import openpyxl except ImportError: - raise Exception("It doesn't look like this machine is configured for " - "excel export. To export to excel you have to run the " - "command: pip install openpyxl") + raise Exception( + "It doesn't look like this machine is configured for " + "Excel export. To export to Excel you have to run the " + "command: pip install openpyxl" + ) self.file = file self.book = openpyxl.workbook.Workbook(write_only=True) @@ -155,9 +154,11 @@ def __init__(self, file): try: import xlwt except ImportError: - raise Exception("It doesn't look like this machine is configured for " - "excel export. To export to excel you have to run the " - "command: pip install xlwt") + raise Exception( + "It doesn't look like this machine is configured for " + "excel export. 
To export to excel you have to run the " + "command: pip install xlwt" + ) self.file = file self.book = xlwt.Workbook() @@ -184,10 +185,10 @@ def get_sheet(self, table): for colnum, val in enumerate(table.headings): sheet.write(0, colnum, ensure_text(val)) - self.sheets[name] = (sheet, 1) # start from row 1 + self.sheets[name] = (sheet, 1) # start from row 1 return self.sheets[name] - + def __exit__(self, exc_type, exc_val, exc_tb): self.book.save(self.file) @@ -199,7 +200,7 @@ class JValueTableWriter(TableWriter): def __init__(self): self.tables = {} - + def write_table(self, table): if table.name not in self.tables: self.tables[table.name] = TableSpec( @@ -210,39 +211,49 @@ def write_table(self, table): else: assert self.tables[table.name].headings == list(table.headings) - self.tables[table.name].rows.extend( - [to_jvalue(v) for v in row] for row in table.rows - ) + self.tables[table.name].rows = list( + self.tables[table.name].rows + ) + [[to_jvalue(v) for v in row] for row in table.rows] class StreamingMarkdownTableWriter(TableWriter): """ - Writes markdown to an output stream, where each table just comes one after the other + Writes markdown to an output stream, where each table just comes one + after the other """ supports_multi_table_write = False def __init__(self, output_stream, compute_widths=False): self.output_stream = output_stream self.compute_widths = compute_widths - - def write_table(self, table, ): + + def write_table(self, table): col_widths = None if self.compute_widths: col_widths = self._get_column_widths(table) - row_template = ' | '.join(['{{:<{}}}'.format(width) for width in col_widths]) + row_template = ' | '.join([ + '{{:<{}}}'.format(width) for width in col_widths + ]) else: row_template = ' | '.join(['{}'] * len(table.headings)) if table.name: self.output_stream.write('\n# %s \n\n' % table.name) - self.output_stream.write('| %s |\n' % row_template.format(*table.headings)) + self.output_stream.write( + '| %s |\n' % row_template.format(*table.headings) + ) if col_widths: - self.output_stream.write('| %s |\n' % row_template.format(*['-' * width for width in col_widths])) + self.output_stream.write( + '| %s |\n' + % row_template.format(*['-' * width for width in col_widths]) + ) for row in table.rows: text_row = (ensure_text(val, convert_none=True) for val in row) - self.output_stream.write('| %s |\n' % row_template.format(*text_row)) + self.output_stream.write( + '| %s |\n' % row_template.format(*text_row) + ) def _get_column_widths(self, table): all_rows = [table.headings] + table.rows @@ -258,12 +269,17 @@ class SqlMixin(object): """ MIN_VARCHAR_LEN = 32 - MAX_VARCHAR_LEN = 255 # Arbitrary point at which we switch to TEXT; for postgres VARCHAR == TEXT anyhow + # Arbitrary point at which we switch to TEXT; for Postgres + # VARCHAR == TEXT anyhow + MAX_VARCHAR_LEN = 255 def __init__(self, db_url, poolclass=None, engine=None): self.db_url = db_url - self.collation = 'utf8_bin' if 'mysql' in db_url else None - self.engine = engine or sqlalchemy.create_engine(db_url, poolclass=poolclass) + self.collation = 'utf8mb4_unicode_ci' if 'mysql' in db_url else None + self.engine = engine or sqlalchemy.create_engine( + db_url, poolclass=poolclass + ) + self._metadata = None def __enter__(self): self.connection = self.engine.connect() @@ -305,24 +321,31 @@ def max_column_length(self): @property def metadata(self): - if not hasattr(self, '_metadata') or self._metadata.bind.closed or self._metadata.bind.invalidated: + if ( + self._metadata is None + or 
self._metadata.bind.closed + or self._metadata.bind.invalidated + ): if self.connection.closed: - raise Exception('Tried to reflect via a closed connection') + raise Exception('Tried to bind to a closed connection') if self.connection.invalidated: - raise Exception('Tried to reflect via an invalidated connection') - self._metadata = sqlalchemy.MetaData() - self._metadata.bind = self.connection - self._metadata.reflect() + raise Exception('Tried to bind to an invalidated connection') + self._metadata = sqlalchemy.MetaData(bind=self.connection) return self._metadata - def table(self, table_name): - return sqlalchemy.Table(table_name, self.metadata, autoload=True, autoload_with=self.connection) + def get_table(self, table_name): + try: + return sqlalchemy.Table( + table_name, + self.metadata, + autoload_with=self.connection, + ) + except NoSuchTableError: + return None def get_id_column(self): return sqlalchemy.Column( - 'id', - sqlalchemy.Unicode(self.MAX_VARCHAR_LEN), - primary_key=True + 'id', sqlalchemy.Unicode(self.MAX_VARCHAR_LEN), primary_key=True ) @@ -348,41 +371,63 @@ def get_explicit_type(self, data_type): return get_sqlalchemy_type(data_type) except UnknownDataType: if data_type: - logger.warning("Found unknown data type '{data_type}'".format( - data_type=data_type, - )) + logger.warning( + "Found unknown data type '{data_type}'".format( + data_type=data_type, + ) + ) return self.best_type_for('') # todo: more explicit fallback def best_type_for(self, val): if isinstance(val, bool): return sqlalchemy.Boolean() elif isinstance(val, datetime.datetime): - return sqlalchemy.DateTime() + if self.is_mssql: + return sqlalchemy.dialects.mssql.DATETIME2() + else: + return sqlalchemy.DateTime() elif isinstance(val, datetime.date): return sqlalchemy.Date() if isinstance(val, int): return sqlalchemy.Integer() - elif isinstance(val, six.string_types): + elif isinstance(val, str): if self.is_postgres: - # PostgreSQL is the best; you can use TEXT everywhere and it works like a charm. + # PostgreSQL is the best; you can use TEXT everywhere + # and it works like a charm. return sqlalchemy.UnicodeText(collation=self.collation) elif self.is_mysql: - # MySQL cannot build an index on TEXT due to the lack of a field length, so we - # try to use VARCHAR when possible. - if len(val) < self.MAX_VARCHAR_LEN: # FIXME: Is 255 an interesting cutoff? - return sqlalchemy.Unicode(max(len(val), self.MIN_VARCHAR_LEN), collation=self.collation) + # MySQL cannot build an index on TEXT due to the lack of + # a field length, so we try to use VARCHAR when + # possible. 
+ if len(val) < self.MAX_VARCHAR_LEN: + return sqlalchemy.Unicode( + max(len(val), self.MIN_VARCHAR_LEN), + collation=self.collation + ) else: return sqlalchemy.UnicodeText(collation=self.collation) elif self.is_mssql: - return sqlalchemy.NVARCHAR(collation=self.collation) - if self.is_oracle: + # MSSQL (pre 2016) doesn't allow indices on columns + # longer than 900 bytes + # https://docs.microsoft.com/en-us/sql/t-sql/statements/create-index-transact-sql + # If any of our data is bigger than this, then set the + # column to NVARCHAR(max) `length` here is the size in + # bytes + # https://docs.sqlalchemy.org/en/13/core/type_basics.html#sqlalchemy.types.String.params.length + length_in_bytes = len(val.encode('utf-8')) + column_length_in_bytes = None if length_in_bytes > 900 else 900 + return sqlalchemy.NVARCHAR( + length=column_length_in_bytes, collation=self.collation + ) + elif self.is_oracle: return sqlalchemy.Unicode(4000, collation=self.collation) else: - raise Exception("Unknown database dialect: {}".format(self.db_url)) + raise Exception(f"Unknown database dialect: {self.db_url}") else: - # We do not have a name for "bottom" in SQL aka the type whose least upper bound - # with any other type is the other type. + # We do not have a name for "bottom" in SQL aka the type + # whose least upper bound with any other type is the other + # type. return sqlalchemy.UnicodeText(collation=self.collation) def compatible(self, source_type, dest_type): @@ -393,25 +438,36 @@ def compatible(self, source_type, dest_type): if not isinstance(dest_type, sqlalchemy.String): return False elif source_type.length is None: - # The length being None means that we are looking at indefinite strings aka TEXT. - # This tool will never create strings with bounds, but if a target DB has one then - # we cannot insert to it. - # We will request that whomever uses this tool convert to TEXT type. + # The length being None means that we are looking at + # indefinite strings aka TEXT. This tool will never + # create strings with bounds, but if a target DB has one + # then we cannot insert to it. We will request that + # whoever uses this tool convert to TEXT type. 
return dest_type.length is None else: - return dest_type.length is None or (dest_type.length >= source_type.length) + return dest_type.length is None or ( + dest_type.length >= source_type.length + ) compatibility = { sqlalchemy.String: (sqlalchemy.Text,), sqlalchemy.Integer: (sqlalchemy.String, sqlalchemy.Text), - sqlalchemy.Boolean: (sqlalchemy.String, sqlalchemy.Text, sqlalchemy.Integer), - sqlalchemy.DateTime: (sqlalchemy.String, sqlalchemy.Text, sqlalchemy.Date), + sqlalchemy.Boolean: + (sqlalchemy.String, sqlalchemy.Text, sqlalchemy.Integer), + sqlalchemy.DateTime: + (sqlalchemy.String, sqlalchemy.Text, sqlalchemy.Date), sqlalchemy.Date: (sqlalchemy.String, sqlalchemy.Text), } # add dialect specific types try: - compatibility[sqlalchemy.Boolean] += (sqlalchemy.dialects.mssql.base.BIT,) + compatibility[sqlalchemy.JSON + ] = (sqlalchemy.dialects.postgresql.json.JSON,) + except AttributeError: + pass + try: + compatibility[sqlalchemy.Boolean + ] += (sqlalchemy.dialects.mssql.base.BIT,) except AttributeError: pass @@ -425,7 +481,7 @@ def strict_types_compatibility_check(self, source_type, dest_type, val): return # Can't do anything elif dest_type.length is None: return # already a TEXT column - elif isinstance(val, six.string_types) and dest_type.length >= len(val): + elif isinstance(val, str) and dest_type.length >= len(val): return # no need to upgrade to TEXT column elif source_type.length is None: return sqlalchemy.UnicodeText(collation=self.collation) @@ -434,101 +490,139 @@ def strict_types_compatibility_check(self, source_type, dest_type, val): def least_upper_bound(self, source_type, dest_type): """ - Returns the _coercion_ least uppper bound. + Returns the _coercion_ least upper bound. + Mostly just promotes everything to string if it is not already. - In fact, since this is only called when they are incompatible, it promotes to string right away. + In fact, since this is only called when they are incompatible, + it promotes to string right away. """ # FIXME: Don't be so silly return sqlalchemy.UnicodeText(collation=self.collation) - def make_table_compatible(self, table_name, row_dict, data_type_dict): - ctx = alembic.migration.MigrationContext.configure(self.connection) - op = alembic.operations.Operations(ctx) - - if not table_name in self.metadata.tables: - if self.strict_types: - create_sql = sqlalchemy.schema.CreateTable(sqlalchemy.Table( - table_name, - sqlalchemy.MetaData(), - *self._get_columns_for_data(row_dict, data_type_dict) - )).compile(self.connection.engine) - logger.warning("Table '{table_name}' does not exist. 
Creating table with:\n{schema}".format( - table_name=table_name, - schema=create_sql - )) - empty_cols = [name for name, val in row_dict.items() - if val is None and name not in data_type_dict] - if empty_cols: - logger.warning("This schema does not include the following columns since we are unable " - "to determine the column type at this stage: {}".format(empty_cols)) - op.create_table(table_name, *self._get_columns_for_data(row_dict, data_type_dict)) - self.metadata.clear() - self.metadata.reflect() - return - - def get_current_table_columns(): - return {c.name: c for c in self.table(table_name).columns} - - columns = get_current_table_columns() + def make_table_compatible(self, table, row_dict, data_type_dict): + ctx = MigrationContext.configure(self.connection) + op = Operations(ctx) + columns = {c.name: c for c in table.columns} for column, val in row_dict.items(): if val is None: continue - ty = self.get_data_type(data_type_dict[column], val) - if not column in columns: - logger.warning("Adding column '{}.{} {}'".format(table_name, column, ty)) - op.add_column(table_name, sqlalchemy.Column(column, ty, nullable=True)) + val_type = self.get_data_type(data_type_dict[column], val) + if column not in columns: + logger.warning( + f"Adding column '{table.name}.{column} {val_type}'" + ) + op.add_column( + table.name, + sqlalchemy.Column(column, val_type, nullable=True) + ) self.metadata.clear() - self.metadata.reflect() - columns = get_current_table_columns() + table = self.get_table(table.name) elif not columns[column].primary_key: - current_ty = columns[column].type - new_type = None + col_type = columns[column].type + new_col_type = None if self.strict_types: - # don't bother checking compatibility since we're not going to change anything - new_type = self.strict_types_compatibility_check(ty, current_ty, val) - elif not self.compatible(ty, current_ty): - new_type = self.least_upper_bound(ty, current_ty) - - if new_type: - logger.warning('Altering column %s from %s to %s for value: "%s:%s"', columns[column], current_ty, new_type, type(val), val) - op.alter_column(table_name, column, type_=new_type) + # don't bother checking compatibility since we're + # not going to change anything + new_col_type = self.strict_types_compatibility_check( + val_type, col_type, val + ) + elif not self.compatible(val_type, col_type): + new_col_type = self.least_upper_bound(val_type, col_type) + + if new_col_type: + logger.warning( + f'Altering column {columns[column]} from {col_type} ' + f'to {new_col_type} for value: "{type(val)}:{val}"', + ) + op.alter_column(table.name, column, type_=new_col_type) self.metadata.clear() - self.metadata.reflect() - columns = get_current_table_columns() + table = self.get_table(table.name) + return table + + def create_table(self, table_name, row_dict, data_type_dict): + ctx = MigrationContext.configure(self.connection) + op = Operations(ctx) + if self.strict_types: + create_sql = sqlalchemy.schema.CreateTable( + sqlalchemy.Table( + table_name, + sqlalchemy.MetaData(), + *self._get_columns_for_data(row_dict, data_type_dict) + ) + ).compile(self.connection.engine) + logger.warning( + f"Table '{table_name}' does not exist. 
Creating table " + f"with:\n{create_sql}" + ) + empty_cols = [ + name for (name, val) in row_dict.items() + if val is None and name not in data_type_dict + ] + if empty_cols: + logger.warning( + "This schema does not include the following columns " + "since we are unable to determine the column type at " + f"this stage: {empty_cols}" + ) + op.create_table( + table_name, + *self._get_columns_for_data(row_dict, data_type_dict) + ) + self.metadata.clear() + return self.get_table(table_name) def upsert(self, table, row_dict): - # For atomicity "insert, catch, update" is slightly better than "select, insert or update". - # The latter may crash, while the former may overwrite data (which should be fine if whatever is - # racing against this is importing from the same source... if not you are busted anyhow - - # strip out values that are None since the column may not exist yet - row_dict = {col: val for col, val in row_dict.items() if val is not None} + # For atomicity "insert, catch, update" is slightly better than + # "select, insert or update". The latter may crash, while the + # former may overwrite data (which should be fine if whatever is + # racing against this is importing from the same source... if + # not you are busted anyhow + + # strip out values that are None since the column may not exist + # yet + row_dict = { + col: val for col, val in row_dict.items() if val is not None + } try: insert = table.insert().values(**row_dict) self.connection.execute(insert) except sqlalchemy.exc.IntegrityError: - update = table.update().where(table.c.id == row_dict['id']).values(**row_dict) + update = (table.update() + .where(table.c.id == row_dict['id']) + .values(**row_dict)) self.connection.execute(update) - def write_table(self, table): - """ - :param table: a TableSpec - """ - table_name = table.name - headings = table.headings - data_type_dict = dict(zip_longest(headings, table.data_types)) - # Rather inefficient for now... - for row in table.rows: + def write_table(self, table_spec: TableSpec) -> None: + table_name = table_spec.name + headings = table_spec.headings + data_type_dict = dict(zip_longest(headings, table_spec.data_types)) + for i, row in enumerate(table_spec.rows): row_dict = dict(zip(headings, row)) - self.make_table_compatible(table_name, row_dict, data_type_dict) - self.upsert(self.table(table_name), row_dict) + if i == 0: + table = self.get_table(table_name) + if table is None: + table = self.create_table( + table_name, + row_dict, + data_type_dict, + ) + # Checks the data type for every cell in every row. Maybe we + # can use a future version of the data dictionary to avoid + # this? 
+ table = self.make_table_compatible(table, row_dict, data_type_dict) + self.upsert(table, row_dict) def _get_columns_for_data(self, row_dict, data_type_dict): return [self.get_id_column()] + [ - sqlalchemy.Column(column_name, self.get_data_type(data_type_dict[column_name], val), nullable=True) - for column_name, val in row_dict.items() - if (val is not None or data_type_dict[column_name]) and column_name != 'id' + sqlalchemy.Column( + column_name, + self.get_data_type(data_type_dict[column_name], val), + nullable=True + ) + for (column_name, val) in row_dict.items() + if ((val is not None or data_type_dict[column_name]) + and column_name != 'id') ] diff --git a/examples/demo-deliveries.json b/examples/demo-deliveries.json index 77e15078..5ac64320 100644 --- a/examples/demo-deliveries.json +++ b/examples/demo-deliveries.json @@ -7,6 +7,9 @@ "Ref": "get_checkpoint_manager" }, "args": [ + { + "Lit": "form" + }, { "Lit": [ "Deliveries", diff --git a/examples/demo-pregnancy-cases-with-forms.json b/examples/demo-pregnancy-cases-with-forms.json index 089e0760..62069f51 100644 --- a/examples/demo-pregnancy-cases-with-forms.json +++ b/examples/demo-pregnancy-cases-with-forms.json @@ -9,6 +9,9 @@ "Ref": "get_checkpoint_manager" }, "args": [ + { + "Lit": "case" + }, { "Lit": [ "Pregnant Mother Cases" @@ -112,6 +115,9 @@ "Ref": "get_checkpoint_manager" }, "args": [ + { + "Lit": "case" + }, { "Lit": [ "CaseToForm" diff --git a/examples/demo-pregnancy-cases.json b/examples/demo-pregnancy-cases.json index e65144eb..68bce7b5 100644 --- a/examples/demo-pregnancy-cases.json +++ b/examples/demo-pregnancy-cases.json @@ -7,6 +7,9 @@ "Ref": "get_checkpoint_manager" }, "args": [ + { + "Lit": "case" + }, { "Lit": [ "Pregnant Mother Cases", diff --git a/examples/demo-registrations.json b/examples/demo-registrations.json index 1c9742fb..01511912 100644 --- a/examples/demo-registrations.json +++ b/examples/demo-registrations.json @@ -7,6 +7,9 @@ "Ref": "get_checkpoint_manager" }, "args": [ + { + "Lit": "form" + }, { "Lit": [ "Registrations" diff --git a/examples/scheduled_run_linux.sh b/examples/scheduled_run_linux.sh new file mode 100644 index 00000000..b485f0fc --- /dev/null +++ b/examples/scheduled_run_linux.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +commcare-export --output-format \ + --output \ + --query \ + --project \ + --commcare-hq https://commcarehq.org \ + --auth-mode apikey \ + --password diff --git a/examples/scheduled_run_windows.bat b/examples/scheduled_run_windows.bat new file mode 100644 index 00000000..2ed1c2b5 --- /dev/null +++ b/examples/scheduled_run_windows.bat @@ -0,0 +1,7 @@ +commcare-export --output-format ^ + --output ^ + --query ^ + --project ^ + --commcare-hq https://commcarehq.org ^ + --auth-mode apikey ^ + --password diff --git a/migrations/versions/3b37b3b06104_added_cursor_to_checkpoint.py b/migrations/versions/3b37b3b06104_added_cursor_to_checkpoint.py new file mode 100644 index 00000000..2e7c05d9 --- /dev/null +++ b/migrations/versions/3b37b3b06104_added_cursor_to_checkpoint.py @@ -0,0 +1,27 @@ +"""Added cursor to checkpoint + +Revision ID: 3b37b3b06104 +Revises: 6f158d161ab6 +Create Date: 2023-08-25 11:10:38.713189 + +""" +from alembic import op +import sqlalchemy as sa + + +revision = '3b37b3b06104' +down_revision = '6f158d161ab6' +branch_labels = None +depends_on = None + + +def upgrade(): + url = op.get_bind().engine.url + collation = 'utf8_bin' if 'mysql' in url.drivername else None + op.add_column( + 'commcare_export_runs', + sa.Column('cursor', sa.Unicode(255, collation=collation)) 
+ ) + +def downgrade(): + op.drop_column('commcare_export_runs', 'cursor') diff --git a/migrations/versions/6f158d161ab6_add_pagination_mode_to_checkpoint.py b/migrations/versions/6f158d161ab6_add_pagination_mode_to_checkpoint.py new file mode 100644 index 00000000..13445ae7 --- /dev/null +++ b/migrations/versions/6f158d161ab6_add_pagination_mode_to_checkpoint.py @@ -0,0 +1,28 @@ +"""Add pagination_mode to checkpoint + +Revision ID: 6f158d161ab6 +Revises: a56c82a8d02e +Create Date: 2021-01-25 15:13:45.996453 + +""" +from alembic import op +import sqlalchemy as sa + + +revision = '6f158d161ab6' +down_revision = 'a56c82a8d02e' +branch_labels = None +depends_on = None + + + +def upgrade(): + url = op.get_bind().engine.url + collation = 'utf8_bin' if 'mysql' in url.drivername else None + op.add_column( + 'commcare_export_runs', + sa.Column('pagination_mode', sa.Unicode(255, collation=collation)) + ) + +def downgrade(): + op.drop_column('commcare_export_runs', 'pagination_mode') diff --git a/migrations/versions/a56c82a8d02e_add_detail_to_checkpoint.py b/migrations/versions/a56c82a8d02e_add_detail_to_checkpoint.py new file mode 100644 index 00000000..2ba1b149 --- /dev/null +++ b/migrations/versions/a56c82a8d02e_add_detail_to_checkpoint.py @@ -0,0 +1,32 @@ +"""Add detail to checkpoint + +Revision ID: a56c82a8d02e +Revises: f4fd4c80f40a +Create Date: 2021-01-22 16:35:07.063082 + +""" +from alembic import op +import sqlalchemy as sa + + +revision = 'a56c82a8d02e' +down_revision = 'f4fd4c80f40a' +branch_labels = None +depends_on = None + + +def upgrade(): + url = op.get_bind().engine.url + collation = 'utf8_bin' if 'mysql' in url.drivername else None + op.add_column( + 'commcare_export_runs', + sa.Column('data_source', sa.Unicode(255, collation=collation)) + ) + op.add_column( + 'commcare_export_runs', + sa.Column('last_doc_id', sa.Unicode(255, collation=collation)) + ) + +def downgrade(): + op.drop_column('commcare_export_runs', 'data_source') + op.drop_column('commcare_export_runs', 'last_doc_id') diff --git a/migrations/versions/c36489c5a628_create_commcare_export_runs.py b/migrations/versions/c36489c5a628_create_commcare_export_runs.py index 302b1b13..15660170 100644 --- a/migrations/versions/c36489c5a628_create_commcare_export_runs.py +++ b/migrations/versions/c36489c5a628_create_commcare_export_runs.py @@ -18,6 +18,7 @@ def upgrade(): meta = sa.MetaData(bind=op.get_bind()) meta.reflect() + if 'commcare_export_runs' not in meta.tables: url = op.get_bind().engine.url collation = 'utf8_bin' if 'mysql' in url.drivername else None diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..94ebe3c5 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,27 @@ +[mypy] +python_version = 3.9 +follow_imports = silent + +# TODO: Get or create stubs for libraries. 
Until then: +ignore_missing_imports = True + +# Typing strictness +check_untyped_defs = True +disallow_subclassing_any = True +disallow_any_generics = True +warn_return_any = True +strict_equality = True +# Set "disallow_untyped_defs = True" for completely typed modules in +# per-module options below + +# Check for drift +warn_redundant_casts = True +warn_unused_ignores = True +warn_unused_configs = True + +# Non-typing checks +implicit_reexport = False +warn_unreachable = True + +# Reporting +show_error_codes = True diff --git a/mypy_typed_modules.txt b/mypy_typed_modules.txt new file mode 100644 index 00000000..d2155693 --- /dev/null +++ b/mypy_typed_modules.txt @@ -0,0 +1,2 @@ +commcare_export/env.py +commcare_export/minilinq.py diff --git a/reports.zip b/reports.zip new file mode 100644 index 00000000..15cb0ecb Binary files /dev/null and b/reports.zip differ diff --git a/setup.py b/setup.py index c56246ad..ae379a87 100644 --- a/setup.py +++ b/setup.py @@ -1,14 +1,13 @@ -from __future__ import print_function -import os.path -import sys import glob -import re import io -import subprocess -import setuptools +import os.path +import re +import sys + +import setuptools from setuptools.command.test import test as TestCommand -VERSION_PATH='commcare_export/VERSION' +VERSION_PATH = 'commcare_export/VERSION' # Overwrite VERSION if we are actually building for a distribution to pypi # This code path requires dependencies, etc, to be available @@ -19,21 +18,24 @@ # This import requires either commcare_export/VERSION or to be in a git clone (as does the package in general) import commcare_export + version = commcare_export.version.version() # Crash if the VERSION is not a simple version and it is going to register or upload if 'register' in sys.argv or 'upload' in sys.argv: version = commcare_export.version.stored_version() - if not version or not re.match('\d+\.\d+\.\d+', version): - print('Version %s is not an appropriate version for publicizing!' % version) + if not version or not re.match(r'\d+\.\d+\.\d+', version): + print('Version %s is not an appropriate version for publicizing!' 
% + version) sys.exit(1) readme = 'README.md' + class PyTest(TestCommand): def finalize_options(self): TestCommand.finalize_options(self) - self.test_args = ['-vv'] + self.test_args = ['-vv', '--tb=short'] self.test_suite = True def run_tests(self): @@ -44,47 +46,60 @@ def run_tests(self): test_deps = ['pytest', 'psycopg2', 'mock'] +base_sql_deps = ["SQLAlchemy", "alembic"] +postgres = ["psycopg2"] +mysql = ["pymysql"] +odbc = ["pyodbc"] -setuptools.setup( - name = "commcare-export", - version = version, - description = 'A command-line tool (and Python library) to extract data from CommCareHQ into a SQL database or Excel workbook', - long_description = io.open(readme, encoding='utf-8').read(), - long_description_content_type = 'text/markdown', - author = 'Dimagi', - author_email = 'information@dimagi.com', - url = "https://github.com/dimagi/commcare-export", - entry_points = { +setuptools.setup( + name="commcare-export", + version=version, + description='A command-line tool (and Python library) to extract data from ' + 'CommCare HQ into a SQL database or Excel workbook', + long_description=io.open(readme, encoding='utf-8').read(), + long_description_content_type='text/markdown', + author='Dimagi', + author_email='information@dimagi.com', + url="https://github.com/dimagi/commcare-export", + entry_points={ 'console_scripts': [ 'commcare-export = commcare_export.cli:entry_point', 'commcare-export-utils = commcare_export.utils_cli:entry_point' ] }, - packages = setuptools.find_packages(exclude=['tests*']), - data_files = [ - (os.path.join('share', 'commcare-export', 'examples'), glob.glob('examples/*.json') + glob.glob('examples/*.xlsx')), + packages=setuptools.find_packages(exclude=['tests*']), + data_files=[ + (os.path.join('share', 'commcare-export', 'examples'), + glob.glob('examples/*.json') + glob.glob('examples/*.xlsx')), ], include_package_data=True, - license = 'MIT', - install_requires = [ + license='MIT', + python_requires=">=3.6", + install_requires=[ 'alembic', 'argparse', - 'jsonpath-rw>=1.2.1', + 'backoff>=2.0', + 'jsonpath-ng~=1.6.0', + 'ndg-httpsclient', 'openpyxl==2.5.12', 'python-dateutil', + 'pytz', 'requests', - 'ndg-httpsclient', 'simplejson', - 'six', - 'sqlalchemy', - 'pytz', - 'sqlalchemy-migrate', - 'backoff', - 'csv342' + 'sqlalchemy~=1.4', + 'sqlalchemy-migrate' ], - tests_require = test_deps, - cmdclass = {'test': PyTest}, - classifiers = [ + extras_require={ + 'test': test_deps, + 'base_sql': base_sql_deps, + 'postgres': base_sql_deps + postgres, + 'mysql': base_sql_deps + mysql, + 'odbc': base_sql_deps + odbc, + 'xlsx': ["openpyxl"], + 'xls': ["xlwt"], + }, + cmdclass={'test': PyTest}, + classifiers=[ 'Development Status :: 4 - Beta', 'Environment :: Console', 'Intended Audience :: Developers', @@ -93,15 +108,15 @@ def run_tests(self): 'Intended Audience :: System Administrators', 'Intended Audience :: End Users/Desktop', 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.6', - 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.3', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Topic :: Database', 'Topic :: Software Development :: Interpreters', 'Topic :: System :: Archiving', 'Topic :: System :: Distributed Computing', - ], - extras_require={'test': 
test_deps} + ] ) diff --git a/tests/003_DataSourceAndEmitColumns.xlsx b/tests/003_DataSourceAndEmitColumns.xlsx index d6962b56..7df68cce 100644 Binary files a/tests/003_DataSourceAndEmitColumns.xlsx and b/tests/003_DataSourceAndEmitColumns.xlsx differ diff --git a/tests/009_expected_form_data.csv b/tests/009_expected_form_data.csv index 14960175..d985e115 100644 --- a/tests/009_expected_form_data.csv +++ b/tests/009_expected_form_data.csv @@ -1,28 +1,27 @@ -id,name,received_on,server_modified_on -3a8776b3-b613-465f-8d2c-431972597222,Sheel,2012-04-24T05:13:01.000000Z,2012-04-24T05:13:01.000000Z -e56abced-bf46-4739-af88-0ec644645b9b,Michel ,2012-04-25T07:02:09.000000Z,2012-04-25T07:02:09.000000Z -4bbd52c6-cef7-41d7-aec8-4a4050c47897,Dionisia,2012-04-25T14:07:05.000000Z,2012-04-25T14:07:05.000000Z -5aa938cc-bade-41d3-baf1-a18b72c0d844,Michel-2.1,2012-04-27T10:05:55.000000Z,2012-04-27T10:05:55.000000Z -674d4fd0-a3df-4b9c-87cd-b6756886581d,Mauro-1,2012-05-02T08:26:09.000000Z,2012-05-02T08:26:09.000000Z -8f209da8-2a4b-4470-86bb-1ae5afcb32d1,Santos,2012-05-10T15:41:32.000000Z,2012-05-10T15:41:32.000000Z -24125f93-67c1-4b91-9988-38588bdead1c,JM,2012-07-13T11:37:22.000000Z,2012-07-13T11:37:22.000000Z -6bd013f2-e549-46b2-ab55-951e90d7cf0d,EUCLIDES1,2012-07-30T10:49:47.000000Z,2012-07-30T10:49:47.000000Z -d6363916-9e54-44d6-b04f-dd34fcb2d0cd,EUCLIDES SYNC 2,2012-07-30T11:07:31.000000Z,2012-07-30T11:07:31.000000Z -0a25740f-c733-4372-a2ea-bb15cac5076c,ECC SYNC 2,2012-07-30T11:07:44.000000Z,2012-07-30T11:07:44.000000Z -43670a3a-d038-4ece-8840-5521a75d2028,ECC SYNC 3,2012-07-30T11:07:51.000000Z,2012-07-30T11:07:51.000000Z -ca54d9e1-d2e8-48ef-b670-1e6df7c49cb8,ECC SYNC 4,2012-07-30T11:10:09.000000Z,2012-07-30T11:10:09.000000Z -1d21a76c-42a0-489b-839f-1c265e6df791,ECC SYNC 2,2012-07-30T11:10:13.000000Z,2012-07-30T11:10:13.000000Z -e2a86f5e-a074-4996-ab2d-c51322451c0a,ECC OTA3,2012-07-30T13:23:58.000000Z,2012-07-30T13:23:58.000000Z -53fd70c3-b79f-4876-b678-f73a5f0dbbe1,ECC OTA2,2012-07-30T14:59:52.000000Z,2012-07-30T14:59:52.000000Z -450d2636-1ebe-4f16-a697-27a759788916,ECC7,2012-07-31T14:49:57.000000Z,2012-07-31T14:49:57.000000Z -54218afe-f32a-49a6-ab0a-35b4673d4887,Euclides 2,2012-08-02T10:15:39.000000Z,2012-08-02T10:15:39.000000Z -40568357-5350-4215-9d51-a646b4d64b70,Euclides Carlos,2012-08-02T10:53:35.000000Z,2012-08-02T10:53:35.000000Z -0c87ad90-2ba0-4806-87a5-daf3473cf829,Euclidez,2012-08-07T05:27:53.000000Z,2012-08-07T05:27:53.000000Z -bbc758fd-ebc9-47b1-9a4f-cc144c0e5839,Euclidez 2,2012-08-07T05:49:04.000000Z,2012-08-07T05:49:04.000000Z -0319a0c3-705c-471d-a164-36e63692b4af,Euclidez new case,2012-08-11T09:03:51.000000Z,2012-08-11T09:03:51.000000Z -b2d58a2e-8c59-4c6e-b221-8e6d23e3dc97,Euclidez new case 2,2012-08-11T09:14:56.000000Z,2012-08-11T09:14:56.000000Z -a31102dc-793a-4f27-b967-d4dfca86b4a6,Guy Mabota,2012-08-16T16:23:40.000000Z,2012-08-16T16:23:40.000000Z -7a20c373-fc6b-4437-ab3a-73d3e4a64be1,EUCLIDES CARLOS,2012-08-26T08:52:37.000000Z,2012-08-26T08:52:37.000000Z -5d39ec4a-7217-48ba-8484-78f8b08124a9,CARLOS,2012-08-26T09:05:57.000000Z,2012-08-26T09:05:57.000000Z -4c4b5ad7-9642-46f2-9947-4a5a5f07973c,MABOTA,2012-08-26T09:07:18.000000Z,2012-08-26T09:07:18.000000Z -6cfb2b4a-6994-415e-bd51-4c25d54628d2,GUY,2012-08-27T06:14:41.000000Z,2012-08-27T06:14:41.000000Z +id,name,inserted_at +722cd0f0-df75-44fe-9f3b-eafdde749556,Register Woman,2017-08-21T15:11:03.897195 +00d52675-48de-4453-a0ab-bf5140e91529,Date Picker Field List OQPS,2017-08-21T22:08:08.903432 +7a30381b-072d-43d5-9a8c-d917c18c4ed0,Date 
Picker Field List OQPS,2017-08-21T23:12:09.212486 +c8db2245-43a5-42bc-bd9a-5637a4e3e3c5,Registration,2017-08-23T00:26:39.476117 +e533dc9b-86ad-4b88-9c68-000485539c84,Registration,2017-08-23T12:40:47.438764 +44849a67-273c-46d8-9f6e-2a1909b2dc64,Date Picker Field List OQPS,2017-08-24T09:25:43.960309 +6debb614-5937-4e72-8d85-79339b183d44,Registration,2017-08-24T13:52:47.079309 +6cf320ce-cb67-4d73-b13f-73d706c25a58,Registration,2017-08-25T04:40:11.724417 +b921ff1d-d8f1-4fa8-8a0d-7bcb89c64900,Registration,2017-08-25T14:21:18.744358 +1bf4c4a2-c74e-433c-861a-12d908ca3b37,Registration,2017-08-27T20:52:06.531792 +a72e6609-c9fc-4929-8316-67cbe7370faa,Registration,2017-08-27T20:52:06.545798 +733c1f60-55e8-40d8-b11e-25e0b60a1bd3,Registration,2017-08-28T01:51:34.108067 +232d1461-97a2-4bf2-a322-bf00bd7631ab,Registration,2017-08-28T02:47:58.986850 +ef936958-063f-4862-9289-ee041154c3c3,Registration,2017-08-29T03:38:54.521605 +47bc1c89-6006-48b1-8e5c-6468f0d7509c,Registration,2017-08-29T03:38:54.656597 +43cf5491-055a-4a8d-839c-10ea888373dd,Registration,2017-08-29T16:40:41.928598 +c5f0eb63-a5d9-4b45-bc30-eb47a25590d0,Registration,2017-08-29T16:40:41.930963 +dfa09ad4-3efd-4a54-a628-25e1ccaded17,Date Picker Field List OQPS,2017-08-29T17:01:45.618568 +13a82b6a-2e82-4e1a-9f17-e57052b622e3,Date Picker Field List OQPS,2017-08-29T17:01:46.310261 +3407042d-4db0-4564-ba4c-f9e64a9fb17f,Date Picker Field List OQPS,2017-08-29T22:34:04.594038 +0bc4799d-ef00-4de8-9061-cca31e1e630e,Registration,2017-08-30T12:31:38.126856 +5ae9056c-a3b0-4b66-a66f-5830523fb6de,Registration,2017-08-30T20:13:18.057229 +d0cf0d73-d453-4245-9baa-d15577618f9f,Registration,2017-08-31T22:56:41.027155 +a6074fd8-9671-444c-ad10-4ea5f14a8b8e,Registration,2017-09-02T09:39:30.272531 +27b6f2e5-0891-4c55-8fd6-51c639c0cd87,Registration,2017-09-02T20:05:35.451835 +3012b9dc-5d1e-410a-aa39-803191e935ac,Registration,2017-09-02T20:05:35.459547 diff --git a/tests/009_integration.xlsx b/tests/009_integration.xlsx index dc8be15a..0b2fcd53 100644 Binary files a/tests/009_integration.xlsx and b/tests/009_integration.xlsx differ diff --git a/tests/009b_expected_form_1_data.csv b/tests/009b_expected_form_1_data.csv index b460d85b..d985e115 100644 --- a/tests/009b_expected_form_1_data.csv +++ b/tests/009b_expected_form_1_data.csv @@ -1,5 +1,27 @@ -id,name,received_on,server_modified_on -3a8776b3-b613-465f-8d2c-431972597222,Sheel,2012-04-24T05:13:01.000000Z,2012-04-24T05:13:01.000000Z -e56abced-bf46-4739-af88-0ec644645b9b,Michel ,2012-04-25T07:02:09.000000Z,2012-04-25T07:02:09.000000Z -4bbd52c6-cef7-41d7-aec8-4a4050c47897,Dionisia,2012-04-25T14:07:05.000000Z,2012-04-25T14:07:05.000000Z -5aa938cc-bade-41d3-baf1-a18b72c0d844,Michel-2.1,2012-04-27T10:05:55.000000Z,2012-04-27T10:05:55.000000Z +id,name,inserted_at +722cd0f0-df75-44fe-9f3b-eafdde749556,Register Woman,2017-08-21T15:11:03.897195 +00d52675-48de-4453-a0ab-bf5140e91529,Date Picker Field List OQPS,2017-08-21T22:08:08.903432 +7a30381b-072d-43d5-9a8c-d917c18c4ed0,Date Picker Field List OQPS,2017-08-21T23:12:09.212486 +c8db2245-43a5-42bc-bd9a-5637a4e3e3c5,Registration,2017-08-23T00:26:39.476117 +e533dc9b-86ad-4b88-9c68-000485539c84,Registration,2017-08-23T12:40:47.438764 +44849a67-273c-46d8-9f6e-2a1909b2dc64,Date Picker Field List OQPS,2017-08-24T09:25:43.960309 +6debb614-5937-4e72-8d85-79339b183d44,Registration,2017-08-24T13:52:47.079309 +6cf320ce-cb67-4d73-b13f-73d706c25a58,Registration,2017-08-25T04:40:11.724417 +b921ff1d-d8f1-4fa8-8a0d-7bcb89c64900,Registration,2017-08-25T14:21:18.744358 
+1bf4c4a2-c74e-433c-861a-12d908ca3b37,Registration,2017-08-27T20:52:06.531792 +a72e6609-c9fc-4929-8316-67cbe7370faa,Registration,2017-08-27T20:52:06.545798 +733c1f60-55e8-40d8-b11e-25e0b60a1bd3,Registration,2017-08-28T01:51:34.108067 +232d1461-97a2-4bf2-a322-bf00bd7631ab,Registration,2017-08-28T02:47:58.986850 +ef936958-063f-4862-9289-ee041154c3c3,Registration,2017-08-29T03:38:54.521605 +47bc1c89-6006-48b1-8e5c-6468f0d7509c,Registration,2017-08-29T03:38:54.656597 +43cf5491-055a-4a8d-839c-10ea888373dd,Registration,2017-08-29T16:40:41.928598 +c5f0eb63-a5d9-4b45-bc30-eb47a25590d0,Registration,2017-08-29T16:40:41.930963 +dfa09ad4-3efd-4a54-a628-25e1ccaded17,Date Picker Field List OQPS,2017-08-29T17:01:45.618568 +13a82b6a-2e82-4e1a-9f17-e57052b622e3,Date Picker Field List OQPS,2017-08-29T17:01:46.310261 +3407042d-4db0-4564-ba4c-f9e64a9fb17f,Date Picker Field List OQPS,2017-08-29T22:34:04.594038 +0bc4799d-ef00-4de8-9061-cca31e1e630e,Registration,2017-08-30T12:31:38.126856 +5ae9056c-a3b0-4b66-a66f-5830523fb6de,Registration,2017-08-30T20:13:18.057229 +d0cf0d73-d453-4245-9baa-d15577618f9f,Registration,2017-08-31T22:56:41.027155 +a6074fd8-9671-444c-ad10-4ea5f14a8b8e,Registration,2017-09-02T09:39:30.272531 +27b6f2e5-0891-4c55-8fd6-51c639c0cd87,Registration,2017-09-02T20:05:35.451835 +3012b9dc-5d1e-410a-aa39-803191e935ac,Registration,2017-09-02T20:05:35.459547 diff --git a/tests/009b_expected_form_2_data.csv b/tests/009b_expected_form_2_data.csv index c27897ad..bb2350b4 100644 --- a/tests/009b_expected_form_2_data.csv +++ b/tests/009b_expected_form_2_data.csv @@ -1,6 +1,4 @@ -id,name,received_on,server_modified_on -bbe20343-e00b-42c2-bede-b86342ed46dd,New Form,2012-04-02T18:38:50.000000Z,2012-04-02T18:38:50.000000Z -0492cb9d-b8e7-4628-9aff-c772a83b1c5b,New Form,2012-04-03T14:51:46.000000Z,2012-04-03T14:51:46.000000Z -162b6042-b96b-4008-8673-1d38b5771307,New Form,2012-04-18T20:07:22.000000Z,2012-04-18T20:07:22.000000Z -9e4f67e0-6d30-4f4b-9c0d-0dfe8bf691c1,New Form,2012-04-23T08:51:13.000000Z,2012-04-23T08:51:13.000000Z -68dd2433-8f52-4dca-a851-bc58f1d71f4a,New Form,2012-04-27T14:23:50.000000Z,2012-04-27T14:23:50.000000Z +id,name,inserted_at +d0cf1846-204b-4d04-819c-f688228c2c9e,Registration Form,2020-05-16T20:04:16.230195 +db38a72d-dd04-4893-9f2f-5548b8e1fa9f,Registration Form,2020-05-16T20:18:47.823616 +f34bec9a-0af3-495d-b53f-3d953e3b3d4b,Registration Form,2020-06-01T17:43:26.107701 diff --git a/tests/009b_integration_multiple.xlsx b/tests/009b_integration_multiple.xlsx index f8dd4667..1faa54dc 100644 Binary files a/tests/009b_integration_multiple.xlsx and b/tests/009b_integration_multiple.xlsx differ diff --git a/tests/013_ConflictingTypes.xlsx b/tests/013_ConflictingTypes.xlsx index a54a4b48..5fcd1937 100644 Binary files a/tests/013_ConflictingTypes.xlsx and b/tests/013_ConflictingTypes.xlsx differ diff --git a/tests/014_ExportWithDataTypes.xlsx b/tests/014_ExportWithDataTypes.xlsx index 67deff40..33fdd8c9 100644 Binary files a/tests/014_ExportWithDataTypes.xlsx and b/tests/014_ExportWithDataTypes.xlsx differ diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py index 5f7f4f9e..decb3cb7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,22 +1,22 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes - import logging import os import uuid -import pytest import sqlalchemy from sqlalchemy.exc import DBAPIError 
-TEST_DB = 'test_commcare_export_%s' % uuid.uuid4().hex +import pytest + +TEST_DB = f'test_commcare_export_{uuid.uuid4().hex}' logging.getLogger().setLevel(logging.DEBUG) logging.getLogger().addHandler(logging.StreamHandler()) def pytest_configure(config): - config.addinivalue_line("markers", "dbtest: mark test that requires database access") + config.addinivalue_line( + "markers", "dbtest: mark test that requires database access" + ) config.addinivalue_line("markers", "postgres: mark PostgreSQL test") config.addinivalue_line("markers", "mysql: mark MySQL test") config.addinivalue_line("markers", "mssql: mark MSSQL test") @@ -24,7 +24,10 @@ def pytest_configure(config): def _db_params(request, db_name): db_url = request.param['url'] - sudo_engine = sqlalchemy.create_engine(db_url % request.param.get('admin_db', ''), poolclass=sqlalchemy.pool.NullPool) + sudo_engine = sqlalchemy.create_engine( + db_url % request.param.get('admin_db', ''), + poolclass=sqlalchemy.pool.NullPool + ) db_connection_url = db_url % db_name def tear_down(): @@ -38,7 +41,10 @@ def tear_down(): try: with sqlalchemy.create_engine(db_connection_url).connect(): pass - except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.InternalError, DBAPIError): + except ( + sqlalchemy.exc.OperationalError, sqlalchemy.exc.InternalError, + DBAPIError + ): with sudo_engine.connect() as conn: if 'postgres' in db_url: conn.execute('rollback') @@ -46,7 +52,9 @@ def tear_down(): conn.connection.connection.autocommit = True conn.execute('create database %s' % db_name) else: - raise Exception('Database %s already exists; refusing to overwrite' % db_name) + raise Exception( + 'Database %s already exists; refusing to overwrite' % db_name + ) request.addfinalizer(tear_down) @@ -55,34 +63,56 @@ def tear_down(): return params -postgres_base = os.environ.get('POSTGRES_URL', 'postgresql://postgres@localhost/') +postgres_base = os.environ.get( + 'POSTGRES_URL', 'postgresql://postgres@localhost/' +) mysql_base = os.environ.get('MYSQL_URL', 'mysql+pymysql://travis@/') -mssql_base = os.environ.get('MSSQL_URL', 'mssql+pyodbc://SA:Password@123@localhost/') - - -@pytest.fixture(scope="class", params=[ - pytest.param({ - 'url': "{}%s".format(postgres_base), - 'admin_db': 'postgres' - }, marks=pytest.mark.postgres), - pytest.param({ - 'url': '{}%s?charset=utf8'.format(mysql_base), - }, marks=pytest.mark.mysql), - pytest.param({ - 'url': '{}%s?driver=ODBC+Driver+17+for+SQL+Server'.format(mssql_base), - 'admin_db': 'master' - }, marks=pytest.mark.mssql) -], ids=['postgres', 'mysql', 'mssql']) +mssql_base = os.environ.get( + 'MSSQL_URL', 'mssql+pyodbc://SA:Password-123@localhost/' +) + + +@pytest.fixture( + scope="class", + params=[ + pytest.param( + { + 'url': f"{postgres_base}%s", + 'admin_db': 'postgres' + }, + marks=pytest.mark.postgres, + ), + pytest.param( + { + 'url': f'{mysql_base}%s?charset=utf8mb4', + }, + marks=pytest.mark.mysql, + ), + pytest.param( + { + 'url': + f'{mssql_base}%s?driver=ODBC+Driver+17+for+SQL+Server', + 'admin_db': + 'master' + }, + marks=pytest.mark.mssql, + ) + ], + ids=['postgres', 'mysql', 'mssql'] +) def db_params(request): return _db_params(request, TEST_DB) -@pytest.fixture(scope="class", params=[ - { - 'url': "{}%s".format(postgres_base), - 'admin_db': 'postgres' - }, -], ids=['postgres']) +@pytest.fixture( + scope="class", + params=[ + { + 'url': "{}%s".format(postgres_base), + 'admin_db': 'postgres' + }, + ], + ids=['postgres'] +) def pg_db_params(request): - return _db_params(request, 'test_commcare_export_%s' % 
uuid.uuid4().hex) - + return _db_params(request, f'test_commcare_export_{uuid.uuid4().hex}') diff --git a/tests/test_checkpointmanager.py b/tests/test_checkpointmanager.py index 91dbb843..19c54583 100644 --- a/tests/test_checkpointmanager.py +++ b/tests/test_checkpointmanager.py @@ -1,32 +1,55 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes - import datetime import uuid -import pytest import sqlalchemy -from commcare_export.checkpoint import CheckpointManager, Checkpoint, session_scope +import pytest +from commcare_export.checkpoint import ( + Checkpoint, + CheckpointManager, + CheckpointManagerProvider, + session_scope, +) +from commcare_export.commcare_minilinq import PaginationMode @pytest.fixture() def manager(db_params): - manager = CheckpointManager(db_params['url'], 'query', '123', 'test', 'hq', poolclass=sqlalchemy.pool.NullPool) + manager = CheckpointManager( + db_params['url'], + 'query', + '123', + 'test', + 'hq', + poolclass=sqlalchemy.pool.NullPool + ) try: yield manager finally: with manager: - manager.connection.execute(sqlalchemy.sql.text('DROP TABLE IF EXISTS commcare_export_runs')) - manager.connection.execute(sqlalchemy.sql.text('DROP TABLE IF EXISTS alembic_version')) + manager.connection.execute( + sqlalchemy.sql + .text('DROP TABLE IF EXISTS commcare_export_runs') + ) + manager.connection.execute( + sqlalchemy.sql.text('DROP TABLE IF EXISTS alembic_version') + ) + + +@pytest.fixture() +def configured_manager(manager): + manager.create_checkpoint_table() + return manager @pytest.mark.dbtest class TestCheckpointManager(object): + def test_create_checkpoint_table(self, manager, revision='head'): manager.create_checkpoint_table(revision) with manager: - assert 'commcare_export_runs' in manager.metadata.tables + table = manager.get_table('commcare_export_runs') + assert table is not None def test_checkpoint_table_exists(self, manager): # Test that the migrations don't fail for tables that existed before @@ -34,68 +57,75 @@ def test_checkpoint_table_exists(self, manager): # This test can be removed at some point in the future. 
self.test_create_checkpoint_table(manager, '9945abb4ec70') with manager: - manager.connection.execute(sqlalchemy.sql.text('DROP TABLE alembic_version')) + manager.connection.execute( + sqlalchemy.sql.text('DROP TABLE alembic_version') + ) manager.create_checkpoint_table() - def test_get_time_of_last_checkpoint(self, manager): - manager.create_checkpoint_table() - manager = manager.for_tables(['t1']) - manager.set_checkpoint(datetime.datetime.utcnow()) + def test_get_time_of_last_checkpoint(self, configured_manager): + manager = configured_manager.for_dataset('form', ['t1']) + manager.set_checkpoint( + datetime.datetime.utcnow(), PaginationMode.date_indexed + ) second_run = datetime.datetime.utcnow() - manager.set_checkpoint(second_run) + manager.set_checkpoint(second_run, PaginationMode.date_indexed) assert manager.get_time_of_last_checkpoint() == second_run.isoformat() - def test_get_last_checkpoint_no_args(self, manager): + def test_get_last_checkpoint_no_args(self, configured_manager): # test that we can still get the time of last run no project and commcare args - manager.create_checkpoint_table() - with session_scope(manager.Session) as session: + with session_scope(configured_manager.Session) as session: since_param = datetime.datetime.utcnow().isoformat() - session.add(Checkpoint( - id=uuid.uuid4().hex, - query_file_name=manager.query, - query_file_md5=manager.query_md5, - project=None, - commcare=None, - since_param=since_param, - time_of_run=datetime.datetime.utcnow().isoformat(), - final=True - )) - manager = manager.for_tables(['t1', 't2']) + session.add( + Checkpoint( + id=uuid.uuid4().hex, + query_file_name=configured_manager.query, + query_file_md5=configured_manager.query_md5, + project=None, + commcare=None, + since_param=since_param, + time_of_run=datetime.datetime.utcnow().isoformat(), + final=True + ) + ) + manager = configured_manager.for_dataset('form', ['t1', 't2']) checkpoint = manager.get_last_checkpoint() assert checkpoint.since_param == since_param assert checkpoint.project == manager.project assert checkpoint.commcare == manager.commcare assert len(manager.get_latest_checkpoints()) == 2 - def test_get_last_checkpoint_no_table(self, manager): + def test_get_last_checkpoint_no_table(self, configured_manager): # test that we can still get the time of last run no table # also tests that new checkoints are created with the tables - manager.create_checkpoint_table() - with session_scope(manager.Session) as session: + with session_scope(configured_manager.Session) as session: since_param = datetime.datetime.utcnow().isoformat() - session.add(Checkpoint( - id=uuid.uuid4().hex, - query_file_name=manager.query, - query_file_md5=manager.query_md5, - project=None, - commcare=None, - since_param=since_param, - time_of_run=datetime.datetime.utcnow().isoformat(), - final=True - )) - - session.add(Checkpoint( - id=uuid.uuid4().hex, - query_file_name=manager.query, - query_file_md5=manager.query_md5, - project=manager.project, - commcare=manager.commcare, - since_param=since_param, - time_of_run=datetime.datetime.utcnow().isoformat(), - final=True - )) - manager = manager.for_tables(['t1', 't2']) + session.add( + Checkpoint( + id=uuid.uuid4().hex, + query_file_name=configured_manager.query, + query_file_md5=configured_manager.query_md5, + project=None, + commcare=None, + since_param=since_param, + time_of_run=datetime.datetime.utcnow().isoformat(), + final=True + ) + ) + + session.add( + Checkpoint( + id=uuid.uuid4().hex, + query_file_name=configured_manager.query, + 
query_file_md5=configured_manager.query_md5, + project=configured_manager.project, + commcare=configured_manager.commcare, + since_param=since_param, + time_of_run=datetime.datetime.utcnow().isoformat(), + final=True + ) + ) + manager = configured_manager.for_dataset('form', ['t1', 't2']) checkpoint = manager.get_last_checkpoint() assert checkpoint.since_param == since_param assert checkpoint.table_name in manager.table_names @@ -103,59 +133,140 @@ def test_get_last_checkpoint_no_table(self, manager): assert len(checkpoints) == 2 assert {c.table_name for c in checkpoints} == set(manager.table_names) - def test_clean_on_final_run(self, manager): - manager.create_checkpoint_table() - manager = manager.for_tables(['t1']) - manager.set_checkpoint(datetime.datetime.utcnow()) - manager.set_checkpoint(datetime.datetime.utcnow()) + def test_clean_on_final_run(self, configured_manager): + manager = configured_manager.for_dataset('form', ['t1']) + manager.set_checkpoint( + datetime.datetime.utcnow(), + PaginationMode.date_indexed, + doc_id="1" + ) + manager.set_checkpoint( + datetime.datetime.utcnow(), + PaginationMode.date_indexed, + doc_id="2" + ) def _get_non_final_rows_count(): with session_scope(manager.Session) as session: return session.query(Checkpoint).filter_by(final=False).count() assert _get_non_final_rows_count() == 2 - manager.set_checkpoint(datetime.datetime.utcnow(), True) + manager.set_checkpoint( + datetime.datetime.utcnow(), + PaginationMode.date_indexed, + True, + doc_id="3" + ) assert _get_non_final_rows_count() == 0 - def test_get_time_of_last_checkpoint_with_key(self, manager): - manager.create_checkpoint_table() - manager = manager.for_tables(['t1']) + def test_get_time_of_last_checkpoint_with_key(self, configured_manager): + manager = configured_manager.for_dataset('form', ['t1']) manager.key = 'my key' last_run_time = datetime.datetime.utcnow() - manager.set_checkpoint(last_run_time) + manager.set_checkpoint(last_run_time, PaginationMode.date_indexed) - assert manager.get_time_of_last_checkpoint() == last_run_time.isoformat() + assert manager.get_time_of_last_checkpoint( + ) == last_run_time.isoformat() manager.key = None assert manager.get_time_of_last_checkpoint() is None - def test_multiple_tables(self, manager): - manager.create_checkpoint_table() + def test_multiple_tables(self, configured_manager): t1 = uuid.uuid4().hex t2 = uuid.uuid4().hex - manager = manager.for_tables([t1, t2]) + manager = configured_manager.for_dataset('form', [t1, t2]) last_run_time = datetime.datetime.utcnow() - manager.set_checkpoint(last_run_time) + doc_id = uuid.uuid4().hex + manager.set_checkpoint( + last_run_time, PaginationMode.date_indexed, doc_id=doc_id + ) - assert manager.for_tables([t1]).get_time_of_last_checkpoint() == last_run_time.isoformat() - assert manager.for_tables([t2]).get_time_of_last_checkpoint() == last_run_time.isoformat() - assert manager.for_tables(['t3']).get_last_checkpoint() is None + assert manager.for_dataset('form', [ + t1 + ]).get_time_of_last_checkpoint() == last_run_time.isoformat() + assert manager.for_dataset('form', [ + t2 + ]).get_time_of_last_checkpoint() == last_run_time.isoformat() + assert manager.for_dataset('form', + ['t3']).get_last_checkpoint() is None checkpoints = manager.list_checkpoints() assert len(checkpoints) == 2 - assert {checkpoints[0].table_name, checkpoints[1].table_name} == {t1, t2} + assert {checkpoints[0].table_name, + checkpoints[1].table_name} == {t1, t2} + assert {checkpoints[0].last_doc_id, + checkpoints[1].last_doc_id} == 
{doc_id} - def test_get_latest_checkpoints(self, manager): - manager.create_checkpoint_table() - manager = manager.for_tables(['t1', 't2']) - manager.set_checkpoint(datetime.datetime.utcnow()) + def test_get_latest_checkpoints(self, configured_manager): + manager = configured_manager.for_dataset('form', ['t1', 't2']) + manager.set_checkpoint( + datetime.datetime.utcnow(), PaginationMode.date_indexed + ) manager.query_md5 = '456' - manager.set_checkpoint(datetime.datetime.utcnow()) + manager.set_checkpoint( + datetime.datetime.utcnow(), PaginationMode.date_indexed + ) latest_time = datetime.datetime.utcnow() - manager.set_checkpoint(latest_time) + manager.set_checkpoint(latest_time, PaginationMode.date_indexed) checkpoints = manager.get_latest_checkpoints() assert len(checkpoints) == 2 assert [c.table_name for c in checkpoints] == ['t1', 't2'] assert {c.query_file_md5 for c in checkpoints} == {'456'} - assert {c.since_param for c in checkpoints} == {latest_time.isoformat()} + assert {c.since_param for c in checkpoints + } == {latest_time.isoformat()} + + +@pytest.mark.parametrize( + 'since, start_over, expected_since, expected_paginator', [ + (None, True, None, PaginationMode.date_indexed), + ('since', False, 'since', PaginationMode.date_indexed), + (None, False, None, PaginationMode.date_indexed), + ] +) +def test_checkpoint_details_static( + since, + start_over, + expected_since, + expected_paginator, +): + cmp = CheckpointManagerProvider(None, since, start_over) + assert expected_since == cmp.get_since(None) + assert expected_paginator == cmp.get_pagination_mode('', None) + + +@pytest.mark.dbtest +class TestCheckpointManagerProvider(object): + + def test_checkpoint_details_no_checkpoint(self, configured_manager): + manager = configured_manager.for_dataset('form', ['t1']) + assert None is CheckpointManagerProvider().get_since(manager) + assert PaginationMode.date_indexed == CheckpointManagerProvider( + ).get_pagination_mode('form', manager) + + def test_checkpoint_details_latest_from_db(self, configured_manager): + data_source = 'form' + manager = configured_manager.for_dataset(data_source, ['t1']) + + self._test_checkpoint_details( + manager, datetime.datetime.utcnow(), PaginationMode.date_modified, data_source + ) + self._test_checkpoint_details( + manager, datetime.datetime.utcnow(), PaginationMode.date_indexed, data_source + ) + self._test_checkpoint_details( + manager, datetime.datetime.utcnow(), PaginationMode.date_modified, data_source + ) + + def _test_checkpoint_details( + self, + manager, + checkpoint_date, + pagination_mode, + data_source, + ): + manager.set_checkpoint(checkpoint_date, pagination_mode) + + cmp = CheckpointManagerProvider() + assert pagination_mode == cmp.get_pagination_mode(data_source, manager) + assert checkpoint_date == cmp.get_since(manager) diff --git a/tests/test_cli.py b/tests/test_cli.py index 6835cc85..01be64bf 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,35 +1,36 @@ -# -*- coding: utf-8 -*- -import csv342 as csv +import csv import os import re import unittest from argparse import Namespace from copy import copy +from itertools import zip_longest +from unittest import mock -import pytest import sqlalchemy -from mock import mock +from tests.utils import SqlWriterWithTearDown -from commcare_export.checkpoint import CheckpointManager -from commcare_export.cli import CLI_ARGS, EXIT_STATUS_ERROR, main_with_args -from commcare_export.commcare_hq_client import MockCommCareHqClient +import pytest +from commcare_export.checkpoint import ( + 
Checkpoint, + CheckpointManager, + session_scope, +) +from commcare_export.cli import CLI_ARGS, main_with_args, validate_output_filename +from commcare_export.commcare_hq_client import ( + CommCareHqClient, + MockCommCareHqClient, + _params_to_url, +) +from commcare_export.commcare_minilinq import PaginationMode from commcare_export.specs import TableSpec -from commcare_export.writers import JValueTableWriter, SqlTableWriter - -CLI_ARGS_BY_NAME = { - arg.name: arg - for arg in CLI_ARGS -} - -try: - from itertools import izip_longest as zip_longest -except ImportError: - # PY 3 - from itertools import zip_longest +from commcare_export.writers import JValueTableWriter +CLI_ARGS_BY_NAME = {arg.name: arg for arg in CLI_ARGS} DEFAULT_BATCH_SIZE = 200 + def make_args(project='test', username='test', password='test', **kwargs): kwargs['project'] = project kwargs['username'] = username @@ -49,252 +50,295 @@ def make_args(project='test', username='test', password='test', **kwargs): def mock_hq_client(include_parent): return MockCommCareHqClient({ - 'form': [ - ( - {'limit': DEFAULT_BATCH_SIZE, 'order_by': ['server_modified_on', 'received_on']}, - [ - {'id': 1, 'form': {'name': 'f1', 'case': {'@case_id': 'c1'}}, - 'metadata': {'userID': 'id1'}}, - {'id': 2, 'form': {'name': 'f2', 'case': {'@case_id': 'c2'}}, - 'metadata': {'userID': 'id2'}}, - ] - ), - ], - 'case': [ - ( - {'limit': DEFAULT_BATCH_SIZE, 'order_by': 'server_date_modified'}, - [ - {'id': 'case1'}, - {'id': 'case2'}, - ] - ) - ], - 'user': [ - ( - {'limit': DEFAULT_BATCH_SIZE}, - [ - {'id': 'id1', 'email': 'em1', 'first_name': 'fn1', - 'last_name': 'ln1', - 'user_data': {'commcare_location_id': 'lid1', - 'commcare_location_ids': ['lid1', 'lid2'], - 'commcare_project': 'p1'}, - 'username': 'u1'}, - {'id': 'id2', 'default_phone_number': 'pn2', 'email': 'em2', - 'first_name': 'fn2', 'last_name': 'ln2', - 'resource_uri': 'ru0', - 'user_data': {'commcare_location_id': 'lid2', - 'commcare_project': 'p2'}, - 'username': 'u2'} - ] - ) - ], - 'location_type': [ - ( - {'limit': DEFAULT_BATCH_SIZE}, - [ - {'administrative': True, 'code': 'hq', 'domain': 'd1', 'id': 1, - 'name': 'HQ', 'parent': None, 'resource_uri': 'lt1', - 'shares_cases': False, 'view_descendants': True}, - {'administrative': False, 'code': 'local', 'domain': 'd1', - 'id': 2, 'name': 'Local', - 'parent': 'lt1', 'resource_uri': 'lt2', - 'shares_cases': True, 'view_descendants': True} - ] - ) - ], - 'location': [ - ( - {'limit': DEFAULT_BATCH_SIZE}, - [ - {'id': 'id1', 'created_at': '2020-04-01T21:57:26.403053', - 'domain': 'd1', 'external_id': 'eid1', - 'last_modified': '2020-04-01T21:58:23.88343', - 'latitude': '11.2', 'location_data': {'p1': 'ld1'}, - 'location_id': 'lid1', 'location_type': 'lt1', - 'longitude': '-20.5', 'name': 'n1', - 'resource_uri': 'ru1', 'site_code': 'sc1'}, - {'id': 'id2', 'created_at': '2020-04-01T21:58:47.627371', - 'domain': 'd2', 'last_modified': '2020-04-01T21:59:16.018411', - 'latitude': '-56.3', 'location_data': {'p1': 'ld2'}, - 'location_id': 'lid2', 'location_type': 'lt2', - 'longitude': '18.7', 'name': 'n2', - 'parent': 'ru1' if include_parent else None, - 'resource_uri': 'ru2', 'site_code': 'sc2'} - ] - ) - ], + 'form': [({ + 'limit': DEFAULT_BATCH_SIZE, + 'order_by': 'indexed_on' + }, [ + { + 'id': 1, + 'form': { + 'name': 'f1', + 'case': { + '@case_id': 'c1' + } + }, + 'metadata': { + 'userID': 'id1' + } + }, + { + 'id': 2, + 'form': { + 'name': 'f2', + 'case': { + '@case_id': 'c2' + } + }, + 'metadata': { + 'userID': 'id2' + } + }, + ]),], + 
'case': [({ + 'limit': DEFAULT_BATCH_SIZE, + 'order_by': 'indexed_on' + }, [ + { + 'id': 'case1' + }, + { + 'id': 'case2' + }, + ])], + 'user': [({ + 'limit': DEFAULT_BATCH_SIZE + }, [{ + 'id': 'id1', + 'email': 'em1', + 'first_name': 'fn1', + 'last_name': 'ln1', + 'user_data': { + 'commcare_location_id': 'lid1', + 'commcare_location_ids': ['lid1', 'lid2'], + 'commcare_project': 'p1' + }, + 'username': 'u1' + }, { + 'id': 'id2', + 'default_phone_number': 'pn2', + 'email': 'em2', + 'first_name': 'fn2', + 'last_name': 'ln2', + 'resource_uri': 'ru0', + 'user_data': { + 'commcare_location_id': 'lid2', + 'commcare_project': 'p2' + }, + 'username': 'u2' + }])], + 'location_type': [({ + 'limit': DEFAULT_BATCH_SIZE + }, [{ + 'administrative': True, + 'code': 'hq', + 'domain': 'd1', + 'id': 1, + 'name': 'HQ', + 'parent': None, + 'resource_uri': 'lt1', + 'shares_cases': False, + 'view_descendants': True + }, { + 'administrative': False, + 'code': 'local', + 'domain': 'd1', + 'id': 2, + 'name': 'Local', + 'parent': 'lt1', + 'resource_uri': 'lt2', + 'shares_cases': True, + 'view_descendants': True + }])], + 'location': [({ + 'limit': DEFAULT_BATCH_SIZE + }, [{ + 'id': 'id1', + 'created_at': '2020-04-01T21:57:26.403053', + 'domain': 'd1', + 'external_id': 'eid1', + 'last_modified': '2020-04-01T21:58:23.88343', + 'latitude': '11.2', + 'location_data': { + 'p1': 'ld1' + }, + 'location_id': 'lid1', + 'location_type': 'lt1', + 'longitude': '-20.5', + 'name': 'n1', + 'resource_uri': 'ru1', + 'site_code': 'sc1' + }, { + 'id': 'id2', + 'created_at': '2020-04-01T21:58:47.627371', + 'domain': 'd2', + 'last_modified': '2020-04-01T21:59:16.018411', + 'latitude': '-56.3', + 'location_data': { + 'p1': 'ld2' + }, + 'location_id': 'lid2', + 'location_type': 'lt2', + 'longitude': '18.7', + 'name': 'n2', + 'parent': 'ru1' if include_parent else None, + 'resource_uri': 'ru2', + 'site_code': 'sc2' + }])], }) -EXPECTED_MULTIPLE_TABLES_RESULTS = [ - { - "name": "Forms", - "headings": ["id", "name"], - "rows": [ - ["1", "f1"], - ["2", "f2"] - ], - }, - { - "name": "Other cases", - "headings": ["id"], - "rows": [ - ["case1"], - ["case2"] - ], - }, - { - "name": "Cases", - "headings": ["case_id"], - "rows": [ - ["c1"], - ["c2"] - ], - } -] -EXPECTED_USERS_RESULTS = [ - { - "name": "commcare_users", - "headings": [ - "id", - "default_phone_number", - "email", - "first_name", - "groups", - "last_name", - "phone_numbers", - "resource_uri", - "commcare_location_id", - "commcare_location_ids", - "commcare_primary_case_sharing_id", - "commcare_project", - "username" - ], - "rows": [ - ["id1", None, "em1", "fn1", None, "ln1", None, None, "lid1", - "lid1,lid2", None, "p1", "u1"], - ["id2", "pn2", "em2", "fn2", None, "ln2", None, "ru0", "lid2", - None, None, "p2", "u2"] - ] - } -] +EXPECTED_MULTIPLE_TABLES_RESULTS = [{ + "name": "Forms", + "headings": ["id", "name"], + "rows": [["1", "f1"], ["2", "f2"]], +}, { + "name": "Other cases", + "headings": ["id"], + "rows": [["case1"], ["case2"]], +}, { + "name": "Cases", + "headings": ["case_id"], + "rows": [["c1"], ["c2"]], +}] + +EXPECTED_USERS_RESULTS = [{ + "name": + "commcare_users", + "headings": [ + "id", "default_phone_number", "email", "first_name", "groups", + "last_name", "phone_numbers", "resource_uri", "commcare_location_id", + "commcare_location_ids", "commcare_primary_case_sharing_id", + "commcare_project", "username" + ], + "rows": [[ + "id1", None, "em1", "fn1", None, "ln1", None, None, "lid1", + "lid1,lid2", None, "p1", "u1" + ], + [ + "id2", "pn2", "em2", "fn2", None, 
"ln2", None, "ru0", "lid2", + None, None, "p2", "u2" + ]] +}] + def get_expected_locations_results(include_parent): - return [ - { - "name": "commcare_locations", - "headings": [ - "id", - "created_at", - "domain", - "external_id", - "last_modified", - "latitude", - "location_data", - "location_id", - "location_type", - "longitude", - "name", - "parent", - "resource_uri", - "site_code", - "location_type_administrative", - "location_type_code", - "location_type_name", - "location_type_parent", - "local", - "hq" - ], - "rows": [ - ["id1", "2020-04-01 21:57:26", "d1", "eid1", - "2020-04-01 21:58:23", "11.2", '{"p1": "ld1", "id": "id1.location_data"}', "lid1", "lt1", - "-20.5", "n1", None, "ru1", "sc1", True, "hq", "HQ", None, - None, "lid1"], - ["id2", "2020-04-01 21:58:47", "d2", None, - "2020-04-01 21:59:16", "-56.3", '{"p1": "ld2", "id": "id2.location_data"}', "lid2", "lt2", - "18.7", "n2", ("ru1" if include_parent else None), "ru2", - "sc2", False, "local", "Local", "lt1", - "lid2", ("lid1" if include_parent else None)] - ] - } - ] + return [{ + "name": + "commcare_locations", + "headings": [ + "id", "created_at", "domain", "external_id", "last_modified", + "latitude", "location_data", "location_id", "location_type", + "longitude", "name", "parent", "resource_uri", "site_code", + "location_type_administrative", "location_type_code", + "location_type_name", "location_type_parent", "local", "hq" + ], + "rows": [[ + "id1", "2020-04-01 21:57:26", "d1", "eid1", "2020-04-01 21:58:23", + "11.2", '{"p1": "ld1", "id": "id1.location_data"}', "lid1", "lt1", + "-20.5", "n1", None, "ru1", "sc1", True, "hq", "HQ", None, None, + "lid1" + ], + [ + "id2", "2020-04-01 21:58:47", "d2", None, + "2020-04-01 21:59:16", "-56.3", + '{"p1": "ld2", "id": "id2.location_data"}', + "lid2", "lt2", "18.7", "n2", + ("ru1" if include_parent else None), "ru2", "sc2", False, + "local", "Local", "lt1", "lid2", + ("lid1" if include_parent else None) + ]] + }] class TestCli(unittest.TestCase): def _test_cli(self, args, expected): writer = JValueTableWriter() - with mock.patch('commcare_export.cli._get_writer', return_value=writer): + with mock.patch( + 'commcare_export.cli._get_writer', return_value=writer + ): main_with_args(args) for table in expected: assert writer.tables[table['name']] == TableSpec(**table) - - @mock.patch('commcare_export.cli._get_api_client', return_value=mock_hq_client(True)) + @mock.patch( + 'commcare_export.cli._get_api_client', + return_value=mock_hq_client(True) + ) def test_cli(self, mock_client): args = make_args( - query='tests/008_multiple-tables.xlsx', - output_format='json' + query='tests/008_multiple-tables.xlsx', output_format='json' ) self._test_cli(args, EXPECTED_MULTIPLE_TABLES_RESULTS) - @mock.patch('commcare_export.cli._get_api_client', return_value=mock_hq_client(True)) + @mock.patch( + 'commcare_export.cli._get_api_client', + return_value=mock_hq_client(True) + ) def test_cli_just_users(self, mock_client): - args = make_args( - output_format='json', - users=True - ) + args = make_args(output_format='json', users=True) self._test_cli(args, EXPECTED_USERS_RESULTS) - @mock.patch('commcare_export.cli._get_api_client', return_value=mock_hq_client(True)) + @mock.patch( + 'commcare_export.cli._get_api_client', + return_value=mock_hq_client(True) + ) def test_cli_table_plus_users(self, mock_client): args = make_args( query='tests/008_multiple-tables.xlsx', output_format='json', users=True ) - self._test_cli(args, EXPECTED_MULTIPLE_TABLES_RESULTS + - EXPECTED_USERS_RESULTS) + self._test_cli( 
+ args, EXPECTED_MULTIPLE_TABLES_RESULTS + EXPECTED_USERS_RESULTS + ) - @mock.patch('commcare_export.cli._get_api_client', return_value=mock_hq_client(True)) + @mock.patch( + 'commcare_export.cli._get_api_client', + return_value=mock_hq_client(True) + ) def test_cli_just_locations(self, mock_client): - args = make_args( - output_format='json', - locations=True - ) + args = make_args(output_format='json', locations=True) self._test_cli(args, get_expected_locations_results(True)) - @mock.patch('commcare_export.cli._get_api_client', return_value=mock_hq_client(False)) + @mock.patch( + 'commcare_export.cli._get_api_client', + return_value=mock_hq_client(False) + ) def test_cli_locations_without_parents(self, mock_client): - args = make_args( - output_format='json', - locations=True - ) + args = make_args(output_format='json', locations=True) self._test_cli(args, get_expected_locations_results(False)) - @mock.patch('commcare_export.cli._get_api_client', return_value=mock_hq_client(True)) + @mock.patch( + 'commcare_export.cli._get_api_client', + return_value=mock_hq_client(True) + ) def test_cli_table_plus_locations(self, mock_client): args = make_args( query='tests/008_multiple-tables.xlsx', output_format='json', locations=True ) - self._test_cli(args, EXPECTED_MULTIPLE_TABLES_RESULTS + - get_expected_locations_results(True)) + self._test_cli( + args, EXPECTED_MULTIPLE_TABLES_RESULTS + + get_expected_locations_results(True) + ) -@pytest.fixture(scope='class') +@pytest.fixture(scope='function') def writer(pg_db_params): - return SqlTableWriter(pg_db_params['url'], poolclass=sqlalchemy.pool.NullPool) + writer = SqlWriterWithTearDown( + pg_db_params['url'], poolclass=sqlalchemy.pool.NullPool + ) + yield writer + writer.tear_down() -@pytest.fixture(scope='class') +@pytest.fixture(scope='function') def checkpoint_manager(pg_db_params): - cm = CheckpointManager(pg_db_params['url'], 'query', '123', 'test', 'hq', poolclass=sqlalchemy.pool.NullPool) + cm = CheckpointManager( + pg_db_params['url'], + 'query', + '123', + 'test', + 'hq', + poolclass=sqlalchemy.pool.NullPool + ) cm.create_checkpoint_table() return cm + def _pull_data(writer, checkpoint_manager, query, since, until, batch_size=10): args = make_args( query=query, @@ -307,40 +351,87 @@ def _pull_data(writer, checkpoint_manager, query, since, until, batch_size=10): batch_size=batch_size, since=since, until=until, + no_logfile=True, ) - # set this so that it get's written to the checkpoints + # set this so that it gets written to the checkpoints checkpoint_manager.query = query - # have to mock these to override the pool class otherwise they hold the db connection open - writer_patch = mock.patch('commcare_export.cli._get_writer', return_value=writer) - checkpoint_patch = mock.patch('commcare_export.cli._get_checkpoint_manager', return_value=checkpoint_manager) + # have to mock these to override the pool class otherwise they hold + # the db connection open + writer_patch = mock.patch( + 'commcare_export.cli._get_writer', return_value=writer + ) + checkpoint_patch = mock.patch( + 'commcare_export.cli._get_checkpoint_manager', + return_value=checkpoint_manager + ) with writer_patch, checkpoint_patch: main_with_args(args) +def _check_data(writer, expected, table_name, columns): + actual = [ + list(row) for row in writer.engine + .execute(f'SELECT {", ".join(columns)} FROM "{table_name}"') + ] + + message = '' + if actual != expected: + message += 'Data not equal to expected:\n' + if len(actual) != len(expected): + message += ' {} rows compared to 
{} expected\n'.format( + len(actual), len(expected) + ) + message += 'Diff:\n' + for i, rows in enumerate(zip_longest(actual, expected)): + if rows[0] != rows[1]: + message += '{}: {} != {}\n'.format(i, rows[0], rows[1]) + assert actual == expected, message + + @pytest.mark.dbtest class TestCLIIntegrationTests(object): - def test_write_to_sql_with_checkpoints(self, writer, checkpoint_manager, caplog): + + def test_write_to_sql_with_checkpoints( + self, writer, checkpoint_manager, caplog + ): with open('tests/009_expected_form_data.csv', 'r') as f: reader = csv.reader(f) expected_form_data = list(reader)[1:] - _pull_data(writer, checkpoint_manager, 'tests/009_integration.xlsx', '2012-01-01', '2012-08-01') + _pull_data( + writer, checkpoint_manager, 'tests/009_integration.xlsx', + '2012-01-01', '2017-08-29' + ) self._check_checkpoints(caplog, ['forms', 'batch', 'final']) - self._check_data(writer, expected_form_data[:16], 'forms') + self._check_data(writer, expected_form_data[:13], 'forms') caplog.clear() - _pull_data(writer, checkpoint_manager, 'tests/009_integration.xlsx', None, '2012-09-01', batch_size=8) + _pull_data( + writer, + checkpoint_manager, + 'tests/009_integration.xlsx', + None, + '2020-10-11', + batch_size=8 + ) self._check_data(writer, expected_form_data, 'forms') self._check_checkpoints(caplog, ['forms', 'batch', 'final']) - runs = list(writer.engine.execute( - 'SELECT * from commcare_export_runs where query_file_name = %s', 'tests/009_integration.xlsx' - )) + runs = list( + writer.engine.execute( + 'SELECT * FROM commcare_export_runs ' + 'WHERE query_file_name = %s', + + 'tests/009_integration.xlsx' + ) + ) assert len(runs) == 2, runs - def test_write_to_sql_with_checkpoints_multiple_tables(self, writer, checkpoint_manager, caplog): + def test_write_to_sql_with_checkpoints_multiple_tables( + self, writer, checkpoint_manager, caplog + ): with open('tests/009b_expected_form_1_data.csv', 'r') as f: reader = csv.reader(f) expected_form_1_data = list(reader)[1:] @@ -349,42 +440,43 @@ def test_write_to_sql_with_checkpoints_multiple_tables(self, writer, checkpoint_ reader = csv.reader(f) expected_form_2_data = list(reader)[1:] - _pull_data(writer, checkpoint_manager, 'tests/009b_integration_multiple.xlsx', None, '2012-05-01') - self._check_checkpoints(caplog, ['forms_1', 'final', 'forms_2', 'final']) - self._check_checkpoints(caplog, ['forms_1', 'forms_1', 'forms_2', 'forms_2']) + _pull_data( + writer, checkpoint_manager, 'tests/009b_integration_multiple.xlsx', + None, '2020-10-11' + ) + self._check_checkpoints( + caplog, ['forms_1', 'batch', 'batch', 'final', 'forms_2', 'final'] + ) + self._check_checkpoints( + caplog, + ['forms_1', 'forms_1', 'forms_1', 'forms_1', 'forms_2', 'forms_2'] + ) self._check_data(writer, expected_form_1_data, 'forms_1') self._check_data(writer, expected_form_2_data, 'forms_2') - runs = list(writer.engine.execute( - 'SELECT table_name, since_param from commcare_export_runs where query_file_name = %s', - 'tests/009b_integration_multiple.xlsx' - )) + runs = list( + writer.engine.execute( + 'SELECT table_name, since_param ' + 'FROM commcare_export_runs ' + 'WHERE query_file_name = %s', + + 'tests/009b_integration_multiple.xlsx' + ) + ) assert {r[0]: r[1] for r in runs} == { - 'forms_1': '2012-04-27T10:05:55', - 'forms_2': '2012-04-27T14:23:50' + 'forms_1': '2017-09-02T20:05:35.459547', + 'forms_2': '2020-06-01T17:43:26.107701', } def _check_data(self, writer, expected, table_name): - actual = [ - list(row) for row in - writer.engine.execute("SELECT id, 
name, received_on, server_modified_on FROM {}".format(table_name)) - ] - - message = '' - if actual != expected: - message += 'Data not equal to expected:\n' - if len(actual) != len(expected): - message += ' {} rows compared to {} expected\n'.format(len(actual), len(expected)) - message += 'Diff:\n' - for i, rows in enumerate(zip_longest(actual, expected)): - if rows[0] != rows[1]: - message += '{}: {} != {}\n'.format(i, rows[0], rows[1]) - assert actual == expected, message + _check_data(writer, expected, table_name, ['id', 'name', 'indexed_on']) def _check_checkpoints(self, caplog, expected): - # Depends on the logging in the CheckpointManager._set_checkpoint method + # Depends on the logging in the CheckpointManager._set_checkpoint + # method log_messages = [ - record[2] for record in caplog.record_tuples + record[2] + for record in caplog.record_tuples if record[0] == 'commcare_export.checkpoint' ] fail = False @@ -398,88 +490,492 @@ def _check_checkpoints(self, caplog, expected): assert not fail, 'Checkpoint comparison failed:\n' + message -# Conflicting types for 'count' will cause errors when inserting into database. +# Conflicting types for 'count' will cause errors when inserting into +# database. CONFLICTING_TYPES_CLIENT = MockCommCareHqClient({ - 'form': [ - ( - {'limit': DEFAULT_BATCH_SIZE, 'order_by': ['server_modified_on', 'received_on']}, - [ - {'id': 1, 'form': {'name': 'n1', 'count': 10}}, - {'id': 2, 'form': {'name': 'f2', 'count': 'abc'}} - ] - ), - ], + 'case': [({ + 'limit': DEFAULT_BATCH_SIZE, + 'order_by': 'indexed_on' + }, [{ + 'id': 1, + 'name': 'n1', + 'count': 10 + }, { + 'id': 2, + 'name': 'f2', + 'count': 'abc' + }]),], }) -@pytest.fixture(scope='class') + +class MockCheckpointingClient(CommCareHqClient): + """ + Mock client that uses the main client for iteration but overrides + the data request to return mocked data. + + Note this client needs to be re-initialized after use. 
+ """ + + def __init__(self, mock_data): + self.mock_data = { + resource: { + _params_to_url(params): result + for params, result in resource_results + } for resource, resource_results in mock_data.items() + } + self.totals = { + resource: sum(len(results) for _, results in resource_results + ) for resource, resource_results in mock_data.items() + } + + def get(self, resource, params=None): + mock_requests = self.mock_data[resource] + key = _params_to_url(params) + objects = mock_requests.pop(key) + if objects: + return { + 'meta': { + 'limit': len(objects), + 'next': bool(mock_requests), + 'offset': 0, + 'previous': None, + 'total_count': self.totals[resource] + }, + 'objects': objects + } + else: + return None + + +def get_conflicting_types_checkpoint_client(): + return MockCheckpointingClient({ + 'case': [ + ({ + 'limit': DEFAULT_BATCH_SIZE, + 'order_by': 'indexed_on' + }, [{ + 'id': "doc 1", + 'name': 'n1', + 'count': 10, + 'indexed_on': '2012-04-23T05:13:01.000000Z' + }, { + 'id': "doc 2", + 'name': 'f2', + 'count': 123, + 'indexed_on': '2012-04-24T05:13:01.000000Z' + }]), + ({ + 'limit': DEFAULT_BATCH_SIZE, + 'order_by': 'indexed_on', + 'indexed_on_start': '2012-04-24T05:13:01' + }, [{ + 'id': "doc 3", + 'name': 'n1', + 'count': 10, + 'indexed_on': '2012-04-25T05:13:01.000000Z' + }, { + 'id': "doc 4", + 'name': 'f2', + 'count': 'abc', + 'indexed_on': '2012-04-26T05:13:01.000000Z' + }]), + ], + }) + + +@pytest.fixture(scope='function') def strict_writer(db_params): - return SqlTableWriter(db_params['url'], poolclass=sqlalchemy.pool.NullPool, strict_types=True) + writer = SqlWriterWithTearDown( + db_params['url'], + poolclass=sqlalchemy.pool.NullPool, + strict_types=True + ) + yield writer + writer.tear_down() -@pytest.fixture(scope='class') + +@pytest.fixture(scope='function') def all_db_checkpoint_manager(db_params): - cm = CheckpointManager(db_params['url'], 'query', '123', 'test', 'hq', poolclass=sqlalchemy.pool.NullPool) + cm = CheckpointManager( + db_params['url'], + 'query', + '123', + 'test', + 'hq', + poolclass=sqlalchemy.pool.NullPool + ) cm.create_checkpoint_table() - return cm - -def _pull_mock_data(writer, checkpoint_manager, api_client, query): + yield cm + with session_scope(cm.Session) as session: + session.query(Checkpoint).delete(synchronize_session='fetch') + + +def _pull_mock_data( + writer, + checkpoint_manager, + api_client, + query, + start_over=None, + since=None +): args = make_args( query=query, output_format='sql', + start_over=start_over, + since=since, ) - # set this so that it get's written to the checkpoints - checkpoint_manager.query = query + assert not (checkpoint_manager and since), \ + "'checkpoint_manager' must be None when using 'since'" - # have to mock these to override the pool class otherwise they hold the db connection open - api_client_patch = mock.patch('commcare_export.cli._get_api_client', - return_value=api_client) - writer_patch = mock.patch('commcare_export.cli._get_writer', return_value=writer) - checkpoint_patch = mock.patch('commcare_export.cli._get_checkpoint_manager', return_value=checkpoint_manager) + if checkpoint_manager: + # set this so that it gets written to the checkpoints + checkpoint_manager.query = query + + # have to mock these to override the pool class otherwise they hold + # the db connection open + api_client_patch = mock.patch( + 'commcare_export.cli._get_api_client', return_value=api_client + ) + writer_patch = mock.patch( + 'commcare_export.cli._get_writer', return_value=writer + ) + checkpoint_patch = mock.patch( 
+ 'commcare_export.cli._get_checkpoint_manager', + return_value=checkpoint_manager + ) with api_client_patch, writer_patch, checkpoint_patch: return main_with_args(args) + @pytest.mark.dbtest class TestCLIWithDatabaseErrors(object): - def test_cli_database_error(self, strict_writer, all_db_checkpoint_manager, capfd): - _pull_mock_data(strict_writer, all_db_checkpoint_manager, CONFLICTING_TYPES_CLIENT, 'tests/013_ConflictingTypes.xlsx') + + def test_cli_database_error( + self, strict_writer, all_db_checkpoint_manager, capfd + ): + _pull_mock_data( + strict_writer, all_db_checkpoint_manager, CONFLICTING_TYPES_CLIENT, + 'tests/013_ConflictingTypes.xlsx' + ) out, err = capfd.readouterr() expected_re = re.compile('Stopping because of database error') assert re.search(expected_re, out) + def test_cli_database_error_checkpoint( + self, strict_writer, all_db_checkpoint_manager, capfd + ): + _pull_mock_data( + strict_writer, all_db_checkpoint_manager, + get_conflicting_types_checkpoint_client(), + 'tests/013_ConflictingTypes.xlsx' + ) + out, err = capfd.readouterr() + + expected_re = re.compile('Stopping because of database error') + assert re.search(expected_re, out), out + + # expect checkpoint to have the date from the first batch and + # not the 2nd + runs = list( + strict_writer.engine.execute( + sqlalchemy.text( + 'SELECT table_name, since_param, last_doc_id ' + 'FROM commcare_export_runs ' + 'WHERE query_file_name = :file' + ), + file='tests/013_ConflictingTypes.xlsx' + ) + ) + assert runs == [ + ('Case', '2012-04-24T05:13:01', 'doc 2'), + ] + -# An input where missing fields should be added due to declared data types. +# An input where missing fields should be added due to declared data +# types. DATA_TYPES_CLIENT = MockCommCareHqClient({ - 'form': [ - ( - {'limit': DEFAULT_BATCH_SIZE, 'order_by': ['server_modified_on', 'received_on']}, - [ - {'id': 1, 'form': {}}, - {'id': 2, 'form': {}} - ] - ), - ], + 'form': [({ + 'limit': DEFAULT_BATCH_SIZE, + 'order_by': 'indexed_on' + }, [{ + 'id': 1, + 'form': {} + }, { + 'id': 2, + 'form': {} + }]),], }) + @pytest.mark.dbtest class TestCLIWithDataTypes(object): - def test_cli_data_types_add_columns(self, strict_writer, all_db_checkpoint_manager, capfd): - _pull_mock_data(strict_writer, all_db_checkpoint_manager, CONFLICTING_TYPES_CLIENT, 'tests/014_ExportWithDataTypes.xlsx') - metadata = sqlalchemy.schema.MetaData(bind=strict_writer.engine, - reflect=True) + def test_cli_data_types_add_columns( + self, + writer, + all_db_checkpoint_manager, + capfd, + ): + _pull_mock_data( + writer, all_db_checkpoint_manager, CONFLICTING_TYPES_CLIENT, + 'tests/014_ExportWithDataTypes.xlsx' + ) - cols = metadata.tables['forms'].c - assert sorted([c.name for c in cols]) == sorted([u'id', u'a_bool', u'an_int', u'a_date', u'a_datetime', u'a_text']) + metadata = sqlalchemy.schema.MetaData(bind=writer.engine) + table = sqlalchemy.Table( + 'forms', + metadata, + autoload_with=writer.engine, + ) + cols = table.c + assert sorted([c.name for c in cols]) == sorted([ + u'id', u'a_bool', u'an_int', u'a_date', u'a_datetime', u'a_text' + ]) - # We intentionally don't check the types because SQLAlchemy doesn't - # support type comparison, and even if we convert to strings, the - # values are backend specific. + # We intentionally don't check the types because SQLAlchemy + # doesn't support type comparison, and even if we convert to + # strings, the values are backend specific. 
values = [ - list(row) for row in - strict_writer.engine.execute('SELECT * FROM forms') + list(row) for row in writer.engine.execute('SELECT * FROM forms') ] assert values == [['1', None, None, None, None, None], ['2', None, None, None, None, None]] + + +def get_indexed_on_client(page): + p1 = MockCheckpointingClient({ + 'case': [({ + 'limit': DEFAULT_BATCH_SIZE, + 'order_by': 'indexed_on' + }, [{ + 'id': "doc 1", + 'name': 'n1', + 'indexed_on': '2012-04-23T05:13:01.000000Z' + }, { + 'id': "doc 2", + 'name': 'n2', + 'indexed_on': '2012-04-24T05:13:01.000000Z' + }])] + }) + p2 = MockCheckpointingClient({ + 'case': [({ + 'limit': DEFAULT_BATCH_SIZE, + 'order_by': 'indexed_on', + 'indexed_on_start': '2012-04-24T05:13:01' + }, [{ + 'id': "doc 3", + 'name': 'n3', + 'indexed_on': '2012-04-25T05:13:01.000000Z' + }, { + 'id': "doc 4", + 'name': 'n4', + 'indexed_on': '2012-04-26T05:13:01.000000Z' + }])] + }) + return [p1, p2][page] + + +@pytest.mark.dbtest +class TestCLIPaginationMode(object): + + def test_cli_pagination_fresh(self, writer, all_db_checkpoint_manager): + checkpoint_manager = all_db_checkpoint_manager.for_dataset( + "case", ["Case"] + ) + + _pull_mock_data( + writer, all_db_checkpoint_manager, get_indexed_on_client(0), + 'tests/013_ConflictingTypes.xlsx' + ) + self._check_data(writer, [["doc 1"], ["doc 2"]], "Case") + self._check_checkpoint( + checkpoint_manager, '2012-04-24T05:13:01', 'doc 2' + ) + + _pull_mock_data( + writer, all_db_checkpoint_manager, get_indexed_on_client(1), + 'tests/013_ConflictingTypes.xlsx' + ) + self._check_data( + writer, [["doc 1"], ["doc 2"], ["doc 3"], ["doc 4"]], "Case" + ) + self._check_checkpoint( + checkpoint_manager, '2012-04-26T05:13:01', 'doc 4' + ) + + def test_cli_pagination_legacy(self, writer, all_db_checkpoint_manager): + """ + Test that we continue with the same pagination mode as was + already in use + """ + + checkpoint_manager = all_db_checkpoint_manager.for_dataset( + "case", ["Case"] + ) + # simulate previous run with legacy pagination mode + checkpoint_manager.set_checkpoint( + '2012-04-24T05:13:01', PaginationMode.date_modified, is_final=True + ) + + client = MockCheckpointingClient({ + 'case': [({ + 'limit': DEFAULT_BATCH_SIZE, + 'order_by': 'server_date_modified', + 'server_date_modified_start': '2012-04-24T05:13:01' + }, [{ + 'id': "doc 1", + 'name': 'n1', + 'server_date_modified': '2012-04-25T05:13:01.000000Z' + }, { + 'id': "doc 2", + 'name': 'n2', + 'server_date_modified': '2012-04-26T05:13:01.000000Z' + }])] + }) + + _pull_mock_data( + writer, all_db_checkpoint_manager, client, + 'tests/013_ConflictingTypes.xlsx' + ) + self._check_data(writer, [["doc 1"], ["doc 2"]], "Case") + self._check_checkpoint( + checkpoint_manager, '2012-04-26T05:13:01', 'doc 2', + PaginationMode.date_modified.name + ) + + def test_cli_pagination_start_over( + self, writer, all_db_checkpoint_manager + ): + """ + Test that we switch to the new pagination mode when using + 'start_over' + """ + checkpoint_manager = all_db_checkpoint_manager.for_dataset( + "case", ["Case"] + ) + # simulate previous run with legacy pagination mode + checkpoint_manager.set_checkpoint( + '2012-04-24T05:13:01', PaginationMode.date_modified, is_final=True + ) + + _pull_mock_data( + writer, + all_db_checkpoint_manager, + get_indexed_on_client(0), + 'tests/013_ConflictingTypes.xlsx', + start_over=True + ) + self._check_data(writer, [["doc 1"], ["doc 2"]], "Case") + self._check_checkpoint( + checkpoint_manager, '2012-04-24T05:13:01', 'doc 2' + ) + + def 
test_cli_pagination_since(self, writer, all_db_checkpoint_manager): + """ + Test that we use to the new pagination mode when using 'since' + """ + checkpoint_manager = all_db_checkpoint_manager.for_dataset( + "case", ["Case"] + ) + # simulate previous run with legacy pagination mode + checkpoint_manager.set_checkpoint( + '2012-04-28T05:13:01', PaginationMode.date_modified, is_final=True + ) + + # this will fail if it doesn't use the 'date_indexed' pagination + # mode due to how the mock client is set up + _pull_mock_data( + writer, + None, + get_indexed_on_client(1), + 'tests/013_ConflictingTypes.xlsx', + since='2012-04-24T05:13:01' + ) + self._check_data(writer, [["doc 3"], ["doc 4"]], "Case") + + def _check_data(self, writer, expected, table_name): + _check_data(writer, expected, table_name, ['id']) + + def _check_checkpoint( + self, + checkpoint_manager, + since_param, + doc_id, + pagination_mode=PaginationMode.date_indexed.name + ): + checkpoint = checkpoint_manager.get_last_checkpoint() + assert checkpoint.pagination_mode == pagination_mode + assert checkpoint.since_param == since_param + assert checkpoint.last_doc_id == doc_id + + +class TestValidateOutputFilename(unittest.TestCase): + def _test_file_extension(self, output_format, expected_extension): + error_message = (f"For output format as {output_format}, " + f"output file name should have extension {expected_extension}") + + errors = validate_output_filename( + output_format=output_format, + output_filename=f'correct_file_extension.{expected_extension}' + ) + self.assertEqual(len(errors), 0) + + errors = validate_output_filename( + output_format=output_format, + output_filename=f'incorrect_file_extension.abc' + ) + self.assertListEqual( + errors, + [error_message] + ) + + # incorrectly using sql output with non sql formats + errors = validate_output_filename( + output_format=output_format, + output_filename='postgresql+psycopg2://scott:tiger@localhost/mydatabase' + ) + self.assertListEqual( + errors, + [error_message] + ) + + def test_for_csv_output(self): + self._test_file_extension(output_format='csv', expected_extension='zip') + + def test_for_xls_output(self): + self._test_file_extension(output_format='xls', expected_extension='xls') + + def test_for_xlsx_output(self): + self._test_file_extension(output_format='xlsx', expected_extension='xlsx') + + def test_for_other_non_sql_output(self): + error_message = "Missing extension in output file name" + + errors = validate_output_filename( + output_format='non_sql', + output_filename='correct_file.abc' + ) + self.assertEqual(len(errors), 0) + + errors = validate_output_filename( + output_format='non_sql', + output_filename='filename_without_extensionxls' + ) + self.assertListEqual( + errors, + [error_message] + ) + + # incorrectly using sql output with non sql output formats + errors = validate_output_filename( + output_format='non_sql', + output_filename='postgresql+psycopg2://scott:tiger@localhost/mydatabase' + ) + self.assertListEqual( + errors, + [error_message] + ) diff --git a/tests/test_commcare_export.py b/tests/test_commcare_export.py new file mode 100644 index 00000000..cd367c82 --- /dev/null +++ b/tests/test_commcare_export.py @@ -0,0 +1,33 @@ +import os +from commcare_export import logger_name_from_filepath, repo_root + + +class TestLoggerNameFromFilePath: + + @staticmethod + def _file_path(rel_path): + return os.path.join(repo_root, rel_path) + + def test_file_in_root(self): + path = self._file_path("file.py") + assert logger_name_from_filepath(path) == 'file' + + def 
test_file_in_subdirectory(self): + path = self._file_path("subdir/file.py") + assert logger_name_from_filepath(path) == 'subdir.file' + + def test_file_in_deeper_subdirectory(self): + path = self._file_path("subdir/another_sub/file.py") + assert logger_name_from_filepath(path) == 'subdir.another_sub.file' + + def test_file_contains_py(self): + path = self._file_path("subdir/pytest.py") + assert logger_name_from_filepath(path) == 'subdir.pytest' + + def test_file_dir_contains_periods(self): + path = self._file_path("sub.dir/pytest.py") + assert logger_name_from_filepath(path) == 'sub.dir.pytest' + + def test_random_file_name(self): + path = self._file_path("pyppy.excel_query.py") + assert logger_name_from_filepath(path) == 'pyppy.excel_query' diff --git a/tests/test_commcare_hq_client.py b/tests/test_commcare_hq_client.py index 0b9ba172..5e0de962 100644 --- a/tests/test_commcare_hq_client.py +++ b/tests/test_commcare_hq_client.py @@ -1,21 +1,26 @@ -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes - -import json import unittest from datetime import datetime -import simplejson - import requests +import simplejson import pytest - -from commcare_export.checkpoint import CheckpointManagerWithSince -from commcare_export.commcare_hq_client import CommCareHqClient, ResourceRepeatException -from commcare_export.commcare_minilinq import SimplePaginator, DatePaginator, resource_since_params, get_paginator - +from commcare_export.checkpoint import CheckpointManagerWithDetails +from commcare_export.commcare_hq_client import ( + CommCareHqClient, + ResourceRepeatException, +) +from commcare_export.commcare_minilinq import ( + DATE_PARAMS, + DatePaginator, + PaginationMode, + SimplePaginator, + get_paginator, +) +from mock import Mock, patch class FakeSession(object): + def get(self, resource_url, params=None, auth=None, timeout=None): result = self._get_results(params) # Mutatey construction method required by requests.Response @@ -28,93 +33,234 @@ def _get_results(self, params): if params: assert params['offset'][0] == '1' return { - 'meta': { 'next': None, 'offset': params['offset'][0], 'limit': 1, 'total_count': 2 }, - 'objects': [ {'id': 1, 'foo': 2} ] + 'meta': { + 'next': None, + 'offset': params['offset'][0], + 'limit': 1, + 'total_count': 2 + }, + 'objects': [{ + 'id': 1, + 'foo': 2 + }] } else: return { - 'meta': { 'next': '?offset=1', 'offset': 0, 'limit': 1, 'total_count': 2 }, - 'objects': [ {'id': 2, 'foo': 1} ] + 'meta': { + 'next': '?offset=1', + 'offset': 0, + 'limit': 1, + 'total_count': 2 + }, + 'objects': [{ + 'id': 2, + 'foo': 1 + }] } class FakeDateCaseSession(FakeSession): + def _get_results(self, params): if not params: return { - 'meta': {'next': '?offset=1', 'offset': 0, 'limit': 1, 'total_count': 2}, - 'objects': [{'id': 1, 'foo': 1, 'server_date_modified': '2017-01-01T15:36:22Z'}] + 'meta': { + 'next': '?offset=1', + 'offset': 0, + 'limit': 1, + 'total_count': 2 + }, + 'objects': [{ + 'id': 1, + 'foo': 1, + 'indexed_on': '2017-01-01T15:36:22Z' + }] } else: - since_query_param =resource_since_params['case'].start_param + since_query_param = DATE_PARAMS['indexed_on'].start_param assert params[since_query_param] == '2017-01-01T15:36:22' # include ID=1 again to make sure it gets filtered out return { - 'meta': { 'next': None, 'offset': 1, 'limit': 1, 'total_count': 2 }, - 'objects': [ {'id': 1, 'foo': 1}, {'id': 2, 'foo': 2} ] + 'meta': { + 'next': None, + 'offset': 1, + 'limit': 1, + 'total_count': 2 + }, + 'objects': [{ 
+ 'id': 1, + 'foo': 1 + }, { + 'id': 2, + 'foo': 2 + }] } class FakeRepeatedDateCaseSession(FakeSession): # Model the case where there are as many or more cases with the same - # server_date_modified than the batch size (2), so the client requests - # the same set of cases in a loop. + # indexed_on than the batch size (2), so the client requests the + # same set of cases in a loop. def _get_results(self, params): if not params: return { - 'meta': {'next': '?offset=1', 'offset': 0, 'limit': 2, 'total_count': 4}, - 'objects': [{'id': 1, 'foo': 1, 'server_date_modified': '2017-01-01T15:36:22Z'}, - {'id': 2, 'foo': 2, 'server_date_modified': '2017-01-01T15:36:22Z'}] + 'meta': { + 'next': '?offset=1', + 'offset': 0, + 'limit': 2, + 'total_count': 4 + }, + 'objects': [{ + 'id': 1, + 'foo': 1, + 'indexed_on': '2017-01-01T15:36:22Z' + }, { + 'id': 2, + 'foo': 2, + 'indexed_on': '2017-01-01T15:36:22Z' + }] } else: - since_query_param =resource_since_params['case'].start_param + since_query_param = DATE_PARAMS['indexed_on'].start_param assert params[since_query_param] == '2017-01-01T15:36:22' return { - 'meta': { 'next': '?offset=1', 'offset': 0, 'limit': 2, 'total_count': 4}, - 'objects': [{'id': 1, 'foo': 1, 'server_date_modified': '2017-01-01T15:36:22Z'}, - {'id': 2, 'foo': 2, 'server_date_modified': '2017-01-01T15:36:22Z'}] + 'meta': { + 'next': '?offset=1', + 'offset': 0, + 'limit': 2, + 'total_count': 4 + }, + 'objects': [{ + 'id': 1, + 'foo': 1, + 'indexed_on': '2017-01-01T15:36:22Z' + }, { + 'id': 2, + 'foo': 2, + 'indexed_on': '2017-01-01T15:36:22Z' + }] + } + + +class FakeMessageLogSession(FakeSession): + + def _get_results(self, params): + obj_1 = { + 'id': 1, + 'foo': 1, + 'date_last_activity': '2017-01-01T15:36:22Z' + } + obj_2 = { + 'id': 2, + 'foo': 2, + 'date_last_activity': '2017-01-01T15:37:22Z' + } + obj_3 = { + 'id': 3, + 'foo': 3, + 'date_last_activity': '2017-01-01T15:38:22Z' + } + if not params: + return { + 'meta': { + 'next': '?cursor=xyz', + 'limit': 2 + }, + 'objects': [obj_1, obj_2] } + else: + since_query_param = DATE_PARAMS['date_last_activity'].start_param + since = params[since_query_param] + if since == '2017-01-01T15:37:22': + return { + 'meta': { + 'next': '?cursor=xyz', + 'limit': 2 + }, + 'objects': [obj_3] + } + if since == '2017-01-01T15:38:22': + return {'meta': {'next': None, 'limit': 2}, 'objects': []} + + raise Exception(since) class FakeDateFormSession(FakeSession): + def _get_results(self, params): since1 = '2017-01-01T15:36:22' since2 = '2017-01-01T16:00:00' if not params: return { - 'meta': {'next': '?offset=1', 'offset': 0, 'limit': 1, 'total_count': 3}, - 'objects': [{'id': 1, 'foo': 1, 'received_on': '{}Z'.format(since1)}] + 'meta': { + 'next': '?offset=1', + 'offset': 0, + 'limit': 1, + 'total_count': 3 + }, + 'objects': [{ + 'id': 1, + 'foo': 1, + 'indexed_on': '{}Z'.format(since1) + }] } else: - search = json.loads(params['_search']) - _or = search['filter']['or'] - received_on = _or[1]['and'][1]['range']['received_on']['gte'] - modified_on = _or[0]['and'][1]['range']['server_modified_on']['gte'] - if received_on == modified_on == since1: + since_query_param = DATE_PARAMS['indexed_on'].start_param + indexed_on = params[since_query_param] + if indexed_on == since1: # include ID=1 again to make sure it gets filtered out return { - 'meta': { 'next': '?offset=2', 'offset': 0, 'limit': 1, 'total_count': 3 }, - 'objects': [{'id': 1, 'foo': 1}, {'id': 2, 'foo': 2, 'server_modified_on': '{}Z'.format(since2)}] + 'meta': { + 'next': '?offset=2', + 'offset': 
0, + 'limit': 1, + 'total_count': 3 + }, + 'objects': [{ + 'id': 1, + 'foo': 1 + }, { + 'id': 2, + 'foo': 2, + 'indexed_on': '{}Z'.format(since2) + }] } - elif received_on == modified_on == since2: + elif indexed_on == since2: return { - 'meta': { 'next': None, 'offset': 0, 'limit': 1, 'total_count': 3 }, - 'objects': [{'id': 3, 'foo': 3}] + 'meta': { + 'next': None, + 'offset': 0, + 'limit': 1, + 'total_count': 3 + }, + 'objects': [{ + 'id': 3, + 'foo': 3 + }] } else: - raise Exception(modified_on) + raise Exception(indexed_on) class TestCommCareHqClient(unittest.TestCase): def _test_iterate(self, session, paginator, expected_count, expected_vals): - client = CommCareHqClient('/fake/commcare-hq/url', 'fake-project', None, None) + client = CommCareHqClient( + '/fake/commcare-hq/url', 'fake-project', None, None + ) client.session = session - # Iteration should do two "gets" because the first will have something in the "next" metadata field + # Iteration should do two "gets" because the first will have + # something in the "next" metadata field paginator.init() - checkpoint_manager = CheckpointManagerWithSince(None, None) - results = list(client.iterate('/fake/uri', paginator, checkpoint_manager=checkpoint_manager)) + checkpoint_manager = CheckpointManagerWithDetails( + None, None, PaginationMode.date_indexed + ) + results = list( + client.iterate( + '/fake/uri', paginator, checkpoint_manager=checkpoint_manager + ) + ) self.assertEqual(len(results), expected_count) self.assertEqual([result['foo'] for result in results], expected_vals) @@ -122,13 +268,102 @@ def test_iterate_simple(self): self._test_iterate(FakeSession(), SimplePaginator('fake'), 2, [1, 2]) def test_iterate_date(self): - self._test_iterate(FakeDateFormSession(), get_paginator('form'), 3, [1, 2, 3]) - self._test_iterate(FakeDateCaseSession(), get_paginator('case'), 2, [1, 2]) + self._test_iterate( + FakeDateFormSession(), get_paginator('form'), 3, [1, 2, 3] + ) + self._test_iterate( + FakeDateCaseSession(), get_paginator('case'), 2, [1, 2] + ) def test_repeat_limit(self): - with pytest.raises(ResourceRepeatException, - match="Requested resource '/fake/uri' 10 times with same parameters"): - self._test_iterate(FakeRepeatedDateCaseSession(), get_paginator('case', 2), 2, [1, 2]) + with pytest.raises( + ResourceRepeatException, + match="Requested resource '/fake/uri' 10 times with same parameters" + ): + self._test_iterate( + FakeRepeatedDateCaseSession(), get_paginator('case', 2), 2, + [1, 2] + ) + + def test_message_log(self): + self._test_iterate( + FakeMessageLogSession(), get_paginator('messaging-event', 2), 3, + [1, 2, 3] + ) + + @patch("commcare_export.commcare_hq_client.CommCareHqClient.session") + def test_dont_raise_on_too_many_requests(self, session_mock): + response = requests.Response() + response.headers = {'Retry-After': "0.0"} + client = CommCareHqClient( + '/fake/commcare-hq/url', 'fake-project', None, None + ) + + self.assertFalse(client._should_raise_for_status(response)) + + @patch("commcare_export.commcare_hq_client.CommCareHqClient.session") + def test_raise_on_too_many_requests(self, session_mock): + response = requests.Response() + response.headers = {} + + client = CommCareHqClient( + '/fake/commcare-hq/url', 'fake-project', None, None + ) + + self.assertTrue(client._should_raise_for_status(response)) + + @patch('commcare_export.commcare_hq_client.logger') + @patch("commcare_export.commcare_hq_client.CommCareHqClient.session") + def test_get_with_forbidden_response_in_non_debug_mode(self, session_mock, 
logger_mock): + response = requests.Response() + response.status_code = 401 + session_mock.get.return_value = response + + logger_mock.isEnabledFor.return_value = False + + with self.assertRaises(SystemExit): + CommCareHqClient( + '/fake/commcare-hq/url', 'fake-project', None, None + ).get("location") + + logger_mock.error.assert_called_once_with( + "#401 Client Error: None for url: None. " + "Please ensure that your CommCare HQ credentials are correct and auth-mode is passed as 'apikey' " + "if using API Key to authenticate. Also, verify that your account has access to the project " + "and the necessary permissions to use commcare-export.") + + @patch('commcare_export.commcare_hq_client.logger') + @patch("commcare_export.commcare_hq_client.CommCareHqClient.session") + def test_get_with_other_http_failure_response_in_non_debug_mode(self, session_mock, logger_mock): + response = requests.Response() + response.status_code = 404 + session_mock.get.return_value = response + + logger_mock.isEnabledFor.return_value = False + + with self.assertRaises(SystemExit): + CommCareHqClient( + '/fake/commcare-hq/url', 'fake-project', None, None + ).get("location") + + logger_mock.error.assert_called_once_with( + "404 Client Error: None for url: None") + + @patch('commcare_export.commcare_hq_client.logger') + @patch("commcare_export.commcare_hq_client.CommCareHqClient.session") + def test_get_with_http_failure_response_in_debug_mode(self, session_mock, logger_mock): + response = requests.Response() + response.status_code = 404 + session_mock.get.return_value = response + + logger_mock.isEnabledFor.return_value = True + + try: + CommCareHqClient( + '/fake/commcare-hq/url', 'fake-project', None, None + ).get("location") + except Exception as e: + self.assertEqual(str(e), "404 Client Error: None for url: None") class TestDatePaginator(unittest.TestCase): @@ -138,27 +373,41 @@ def setup_class(cls): pass def test_empty_batch(self): - self.assertIsNone(DatePaginator('fake', 'since').next_page_params_from_batch({'objects': []})) + self.assertIsNone( + DatePaginator('since', params=SimplePaginator() + ).next_page_params_from_batch({'objects': []}) + ) def test_bad_date(self): - self.assertIsNone(DatePaginator('fake', 'since').next_page_params_from_batch({'objects': [{ - 'since': 'not a date' - }]})) + self.assertIsNone( + DatePaginator('since', params=SimplePaginator() + ).next_page_params_from_batch({ + 'objects': [{ + 'since': 'not a date' + }] + }) + ) def test_multi_field_sort(self): d1 = '2017-01-01T15:36:22Z' d2 = '2017-01-01T18:36:22Z' - self.assertEqual(DatePaginator('fake', ['s1', 's2']).get_since_date({'objects': [{ - 's1': d1, - 's2': d2 - }]}), datetime.strptime(d1, '%Y-%m-%dT%H:%M:%SZ')) - - self.assertEqual(DatePaginator('fake', ['s1', 's2']).get_since_date({'objects': [{ - 's2': d2 - }]}), datetime.strptime(d2, '%Y-%m-%dT%H:%M:%SZ')) + paginator = DatePaginator(['s1', 's2'], params=SimplePaginator()) + self.assertEqual( + paginator.get_since_date({'objects': [{ + 's1': d1, + 's2': d2 + }]}), datetime.strptime(d1, '%Y-%m-%dT%H:%M:%SZ') + ) - self.assertEqual(DatePaginator('fake', ['s1', 's2']).get_since_date({'objects': [{ - 's1': None, - 's2': d2 - }]}), datetime.strptime(d2, '%Y-%m-%dT%H:%M:%SZ')) + self.assertEqual( + paginator.get_since_date({'objects': [{ + 's2': d2 + }]}), datetime.strptime(d2, '%Y-%m-%dT%H:%M:%SZ') + ) + self.assertEqual( + paginator.get_since_date({'objects': [{ + 's1': None, + 's2': d2 + }]}), datetime.strptime(d2, '%Y-%m-%dT%H:%M:%SZ') + ) diff --git 
a/tests/test_commcare_minilinq.py b/tests/test_commcare_minilinq.py index 425bb928..516ab303 100644 --- a/tests/test_commcare_minilinq.py +++ b/tests/test_commcare_minilinq.py @@ -1,13 +1,13 @@ import unittest from itertools import * -from jsonpath_rw import jsonpath - -from commcare_export.checkpoint import CheckpointManagerWithSince -from commcare_export.minilinq import * -from commcare_export.env import * +from commcare_export.checkpoint import CheckpointManagerWithDetails from commcare_export.commcare_hq_client import MockCommCareHqClient from commcare_export.commcare_minilinq import * +from commcare_export.env import * +from commcare_export.minilinq import * +from jsonpath_ng import jsonpath + class TestCommCareMiniLinq(unittest.TestCase): @@ -17,106 +17,184 @@ def setup_class(cls): def check_case(self, val, result): if isinstance(result, list): - assert [datum.value if isinstance(datum, jsonpath.DatumInContext) else datum for datum in val] == result + assert [ + datum.value + if isinstance(datum, jsonpath.DatumInContext) else datum + for datum in val + ] == result + + def test_eval_indexed_on(self): + self._test_eval(PaginationMode.date_indexed) + + def test_eval_modified_on(self): + self._test_eval(PaginationMode.date_modified) + + def _test_eval(self, pagination_mode): + form_order_by = get_paginator( + 'form', pagination_mode=pagination_mode + ).since_field + case_order_by = get_paginator( + 'case', pagination_mode=pagination_mode + ).since_field + + def die(msg): + raise Exception(msg) - def test_eval(self): - def die(msg): raise Exception(msg) - client = MockCommCareHqClient({ 'form': [ ( - {'limit': 1000, 'filter': 'test1', 'order_by': ['server_modified_on', 'received_on']}, + { + 'limit': 1000, + 'filter': 'test1', + 'order_by': form_order_by + }, [1, 2, 3], ), - ( - {'limit': 1000, 'filter': 'test2', 'order_by': ['server_modified_on', 'received_on']}, - [ - { 'x': [{ 'y': 1 }, {'y': 2}] }, - { 'x': [{ 'y': 3 }, {'z': 4}] }, - { 'x': [{ 'y': 5 }] } - ] - ), - ( - {'limit': 1000, 'filter': 'laziness-test', 'order_by': ['server_modified_on', 'received_on']}, - (i if i < 5 else die('Not lazy enough') for i in range(12)) - ), - ( - {'limit': 1000, 'cases__full': 'true', 'order_by': ['server_modified_on', 'received_on']}, - [1, 2, 3, 4, 5] - ), - ], - - 'case': [ - ( - {'limit': 1000, 'type': 'foo', 'order_by': 'server_date_modified'}, - [ - { 'x': 1 }, - { 'x': 2 }, - { 'x': 3 }, - ] - ) + ({ + 'limit': 1000, + 'filter': 'test2', + 'order_by': form_order_by + }, [{ + 'x': [{ + 'y': 1 + }, { + 'y': 2 + }] + }, { + 'x': [{ + 'y': 3 + }, { + 'z': 4 + }] + }, { + 'x': [{ + 'y': 5 + }] + }]), + ({ + 'limit': 1000, + 'filter': 'laziness-test', + 'order_by': form_order_by + }, + (i if i < 5 else die('Not lazy enough') for i in range(12))), + ({ + 'limit': 1000, + 'cases__full': 'true', + 'order_by': form_order_by + }, [1, 2, 3, 4, 5]), ], - - 'user': [ - ( - {'limit': 1000}, - [ - { 'x': 1 }, - { 'x': 2 }, - { 'x': 3 }, - ] - ) - ] + 'case': [({ + 'limit': 1000, + 'type': 'foo', + 'order_by': case_order_by + }, [ + { + 'x': 1 + }, + { + 'x': 2 + }, + { + 'x': 3 + }, + ])], + 'user': [({ + 'limit': 1000 + }, [ + { + 'x': 1 + }, + { + 'x': 2 + }, + { + 'x': 3 + }, + ])] }) - env = BuiltInEnv() | CommCareHqEnv(client) | JsonPathEnv({}) # {'form': api_client.iterate('form')}) - - checkpoint_manager = CheckpointManagerWithSince(None, None) - assert list(Apply(Reference('api_data'), - Literal('form'), - Literal(checkpoint_manager), - Literal({"filter": 'test1'})).eval(env)) == [1, 2, 3] - - # 
just check that we can still apply some deeper xpath by mapping; first ensure the basics work - assert list(Apply(Reference('api_data'), - Literal('form'), - Literal(checkpoint_manager), - Literal({"filter": 'test2'})).eval(env)) == [ - { 'x': [{ 'y': 1 }, {'y': 2}] }, - { 'x': [{ 'y': 3 }, {'z': 4}] }, - { 'x': [{ 'y': 5 }] } - ] - - self.check_case(FlatMap(source=Apply(Reference('api_data'), - Literal('form'), - Literal(checkpoint_manager), - Literal({"filter": 'test2'})), - body=Reference('x[*].y')).eval(env), - [1, 2, 3, 5]) - - self.check_case(islice(Apply(Reference('api_data'), - Literal('form'), - Literal(checkpoint_manager), - Literal({"filter": "laziness-test"})).eval(env), 5), - [0, 1, 2, 3, 4]) - - self.check_case(Apply(Reference('api_data'), - Literal('form'), - Literal(checkpoint_manager), - Literal(None), - Literal(['cases'])).eval(env), - [1, 2, 3, 4, 5]) - - self.check_case(FlatMap(source=Apply(Reference('api_data'), - Literal('case'), - Literal(checkpoint_manager), - Literal({'type': 'foo'})), - body=Reference('x')).eval(env), - [1, 2, 3]) - - self.check_case(FlatMap(source=Apply(Reference('api_data'), - Literal('user'), - Literal(checkpoint_manager), - Literal(None)), - body=Reference('x')).eval(env), - [1, 2, 3]) + env = BuiltInEnv() | CommCareHqEnv(client) | JsonPathEnv( + {} + ) # {'form': api_client.iterate('form')}) + + checkpoint_manager = CheckpointManagerWithDetails( + None, None, pagination_mode + ) + assert list( + Apply( + Reference('api_data'), Literal('form'), + Literal(checkpoint_manager), Literal({"filter": 'test1'}) + ).eval(env) + ) == [1, 2, 3] + + # just check that we can still apply some deeper xpath by + # mapping; first ensure the basics work + assert list( + Apply( + Reference('api_data'), Literal('form'), + Literal(checkpoint_manager), Literal({"filter": 'test2'}) + ).eval(env) + ) == [{ + 'x': [{ + 'y': 1 + }, { + 'y': 2 + }] + }, { + 'x': [{ + 'y': 3 + }, { + 'z': 4 + }] + }, { + 'x': [{ + 'y': 5 + }] + }] + + self.check_case( + FlatMap( + source=Apply( + Reference('api_data'), Literal('form'), + Literal(checkpoint_manager), Literal({"filter": 'test2'}) + ), + body=Reference('x[*].y') + ).eval(env), [1, 2, 3, 5] + ) + + self.check_case( + islice( + Apply( + Reference('api_data'), Literal('form'), + Literal(checkpoint_manager), + Literal({"filter": "laziness-test"}) + ).eval(env), 5 + ), [0, 1, 2, 3, 4] + ) + + self.check_case( + Apply( + Reference('api_data'), Literal('form'), + Literal(checkpoint_manager), Literal(None), Literal(['cases']) + ).eval(env), [1, 2, 3, 4, 5] + ) + + self.check_case( + FlatMap( + source=Apply( + Reference('api_data'), Literal('case'), + Literal(checkpoint_manager), Literal({'type': 'foo'}) + ), + body=Reference('x') + ).eval(env), [1, 2, 3] + ) + + self.check_case( + FlatMap( + source=Apply( + Reference('api_data'), Literal('user'), + Literal(checkpoint_manager), Literal(None) + ), + body=Reference('x') + ).eval(env), [1, 2, 3] + ) diff --git a/tests/test_env.py b/tests/test_env.py new file mode 100644 index 00000000..1e63c295 --- /dev/null +++ b/tests/test_env.py @@ -0,0 +1,8 @@ +import doctest + +import commcare_export.env + + +def test_doctests(): + results = doctest.testmod(commcare_export.env) + assert results.failed == 0 diff --git a/tests/test_excel_query.py b/tests/test_excel_query.py index fcab6592..00176656 100644 --- a/tests/test_excel_query.py +++ b/tests/test_excel_query.py @@ -1,16 +1,13 @@ -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes - 
import os.path import pprint import unittest import openpyxl -from commcare_export.env import BuiltInEnv -from commcare_export.env import JsonPathEnv +from commcare_export.builtin_queries import ColumnEnforcer +from commcare_export.env import BuiltInEnv, JsonPathEnv from commcare_export.excel_query import * from commcare_export.excel_query import _get_safe_source_field -from commcare_export.builtin_queries import ColumnEnforcer class TestExcelQuery(unittest.TestCase): @@ -20,19 +17,35 @@ def setup_class(cls): pass def test_split_leftmost(self): - assert split_leftmost(parse_jsonpath('foo')) == (jsonpath.Fields('foo'), jsonpath.This()) - assert split_leftmost(parse_jsonpath('foo.baz')) == (jsonpath.Fields('foo'), jsonpath.Fields('baz')) - assert split_leftmost(parse_jsonpath('foo.baz.bar')) == (jsonpath.Fields('foo'), jsonpath.Fields('baz').child(jsonpath.Fields('bar'))) - assert split_leftmost(parse_jsonpath('[*].baz')) == (jsonpath.Slice(), jsonpath.Fields('baz')) - assert split_leftmost(parse_jsonpath('foo[*].baz')) == (jsonpath.Fields('foo'), jsonpath.Slice().child(jsonpath.Fields('baz'))) + assert split_leftmost( + parse_jsonpath('foo') + ) == (jsonpath.Fields('foo'), jsonpath.This()) + assert split_leftmost( + parse_jsonpath('foo.baz') + ) == (jsonpath.Fields('foo'), jsonpath.Fields('baz')) + assert split_leftmost(parse_jsonpath('foo.baz.bar')) == ( + jsonpath.Fields('foo'), + jsonpath.Fields('baz').child(jsonpath.Fields('bar')) + ) + assert split_leftmost( + parse_jsonpath('[*].baz') + ) == (jsonpath.Slice(), jsonpath.Fields('baz')) + assert split_leftmost(parse_jsonpath('foo[*].baz')) == ( + jsonpath.Fields('foo'), + jsonpath.Slice().child(jsonpath.Fields('baz')) + ) def test_get_safe_source_field(self): - assert _get_safe_source_field('foo.bar.baz') == Reference('foo.bar.baz') + assert _get_safe_source_field( + 'foo.bar.baz') == Reference('foo.bar.baz') assert _get_safe_source_field('foo[*].baz') == Reference('foo[*].baz') - assert _get_safe_source_field('foo..baz[*]') == Reference('foo..baz[*]') + assert _get_safe_source_field( + 'foo..baz[*]') == Reference('foo..baz[*]') assert _get_safe_source_field('foo.#baz') == Reference('foo."#baz"') - assert _get_safe_source_field('foo.bar[*]..%baz') == Reference('foo.bar[*].."%baz"') - assert _get_safe_source_field('foo.bar:1.baz') == Reference('foo."bar:1".baz') + assert _get_safe_source_field( + 'foo.bar[*]..%baz') == Reference('foo.bar[*].."%baz"') + assert _get_safe_source_field( + 'foo.bar:1.baz') == Reference('foo."bar:1".baz') try: assert _get_safe_source_field('foo.bar.') @@ -42,28 +55,29 @@ def test_get_safe_source_field(self): def test_compile_mappings(self): test_cases = [ - ('mappings.xlsx', - { - 'a': { - 'w': 12, - 'x': 13, - 'y': 14, - 'z': 15, - 'q': 16, - 'r': 17, - }, - 'b': { - 'www': 'hello', - 'xxx': 'goodbye', - 'yyy': 'what is up', - }, - 'c': { - 1: 'foo', - 2: 'bar', - 3: 'biz', - 4: 'bizzle', - } - }), + ( + 'mappings.xlsx', { + 'a': { + 'w': 12, + 'x': 13, + 'y': 14, + 'z': 15, + 'q': 16, + 'r': 17, + }, + 'b': { + 'www': 'hello', + 'xxx': 'goodbye', + 'yyy': 'what is up', + }, + 'c': { + 1: 'foo', + 2: 'bar', + 3: 'biz', + 4: 'bizzle', + } + } + ), ] def flatten(dd): @@ -74,7 +88,9 @@ def flatten(dd): for filename, mappings in test_cases: abs_path = os.path.join(os.path.dirname(__file__), filename) - compiled = compile_mappings(openpyxl.load_workbook(abs_path)['Mappings']) + compiled = compile_mappings( + openpyxl.load_workbook(abs_path)['Mappings'] + ) # Print will be suppressed by pytest unless it fails if not 
(flatten(compiled) == mappings): print('In %s:' % filename) @@ -86,92 +102,171 @@ def flatten(dd): def test_parse_sheet(self): test_cases = [ - ('001_JustDataSource.xlsx', SheetParts( - name='Forms', headings=[], source=Apply(Reference("api_data"), Literal("form"), Reference('checkpoint_manager')), body=None), + ( + '001_JustDataSource.xlsx', + SheetParts( + name='Forms', + headings=[], + source=Apply( + Reference("api_data"), Literal("form"), + Reference('checkpoint_manager') + ), + body=None, + data_source="form" + ), ), - #('001a_JustDataSource_LibreOffice.xlsx', Emit(table='Forms', headings=[], source=Apply(Reference("api_data"), Literal("form")))), - - ('002_DataSourceAndFilters.xlsx', - SheetParts( - name='Forms', - headings=[], - source=Apply( - Reference("api_data"), - Literal("form"), - Reference("checkpoint_manager"), - Literal({ - 'app_id': 'foobizzle', - 'type': 'intake', - }) - ), - body=None - )), - - ('003_DataSourceAndEmitColumns.xlsx', - SheetParts( - name='Forms', - headings = [ - Literal('Form Type'), Literal('Fecha de Nacimiento'), Literal('Sexo'), - Literal('Danger 0'), Literal('Danger 1'), Literal('Danger Fever'), - Literal('Danger error'), Literal('Danger error'), Literal('special'), - Literal('Danger substring 1'), Literal('Danger substring 2'), - Literal('Danger substring error 3'), Literal('Danger substring error 4'), - Literal('Danger substring error 5') - ], - source=Apply(Reference("api_data"), Literal("form"), Reference('checkpoint_manager')), - body=List([ - Reference("type"), - Apply(Reference("FormatDate"), Reference("date_of_birth")), - Apply(Reference("sexo"), Reference("gender")), - Apply(Reference("selected-at"), Reference("dangers"), Literal(0)), - Apply(Reference("selected-at"), Reference("dangers"), Literal(1)), - Apply(Reference("selected"), Reference("dangers"), Literal('fever')), - Literal('Error: selected-at index must be an integer: selected-at(abc)'), - Literal('Error: Unable to parse: selected(fever'), - Reference('path."#text"'), - Apply(Reference("substr"), Reference("dangers"), Literal(0), Literal(10)), - Apply(Reference("substr"), Reference("dangers"), Literal(4), Literal(3)), - Literal('Error: both substr arguments must be non-negative integers: substr(a, b)'), - Literal('Error: both substr arguments must be non-negative integers: substr(-1, 10)'), - Literal('Error: both substr arguments must be non-negative integers: substr(3, -4)') - ]) - )), - - ('005_DataSourcePath.xlsx', - SheetParts( - name='Forms', - headings = [], - source=Apply(Reference("api_data"), Literal("form"), Reference('checkpoint_manager')), - body=None, - root_expr=Reference('form.delivery_information.child_questions.[*]') - )), - - ('006_IncludeReferencedItems.xlsx', - SheetParts( - name='Forms', - headings=[], - source=Apply( - Reference("api_data"), - Literal("form"), - Reference("checkpoint_manager"), - Literal(None), - Literal(['foo', 'bar', 'bizzle']) - ), - body=None - )), - - ('010_JustDataSourceTableName.xlsx', SheetParts( - name='my_table', headings=[], source=Apply(Reference("api_data"), Literal("form"), Reference('checkpoint_manager')), body=None), + # ( + # '001a_JustDataSource_LibreOffice.xlsx', + # Emit( + # table='Forms', + # headings=[], + # source=Apply(Reference("api_data"), Literal("form")) + # ) + # ), + ( + '002_DataSourceAndFilters.xlsx', + SheetParts( + name='Forms', + headings=[], + source=Apply( + Reference("api_data"), Literal("form"), + Reference("checkpoint_manager"), + Literal({ + 'app_id': ['foobizzle'], + 'type': ['intake'], + }) + ), + 
body=None, + data_source="form" + ) + ), + ( + '003_DataSourceAndEmitColumns.xlsx', + SheetParts( + name='Forms', + headings=[ + Literal('Form Type'), + Literal('Fecha de Nacimiento'), + Literal('Sexo'), + Literal('Danger 0'), + Literal('Danger 1'), + Literal('Danger Fever'), + Literal('Danger error'), + Literal('Danger error'), + Literal('special'), + Literal('Danger substring 1'), + Literal('Danger substring 2'), + Literal('Danger substring error 3'), + Literal('Danger substring error 4'), + Literal('Danger substring error 5') + ], + source=Apply( + Reference("api_data"), Literal("form"), + Reference('checkpoint_manager') + ), + body=List([ + Reference("type"), + Apply( + Reference("FormatDate"), + Reference("date_of_birth") + ), + Apply(Reference("sexo"), Reference("gender")), + Apply( + Reference("selected-at"), Reference("dangers"), + Literal(0) + ), + Apply( + Reference("selected-at"), Reference("dangers"), + Literal(1) + ), + Apply( + Reference("selected"), Reference("dangers"), + Literal('fever') + ), + Literal( + 'Error: selected-at index must be an integer: ' + 'selected-at(abc)' + ), + Literal('Error: Unable to parse: selected(fever'), + Reference('path."#text"'), + Apply( + Reference("substr"), Reference("dangers"), + Literal(0), Literal(10) + ), + Apply( + Reference("substr"), Reference("dangers"), + Literal(4), Literal(3) + ), + Literal( + 'Error: both substr arguments must be ' + 'non-negative integers: substr(a, b)' + ), + Literal( + 'Error: both substr arguments must be ' + 'non-negative integers: substr(-1, 10)' + ), + Literal( + 'Error: both substr arguments must be ' + 'non-negative integers: substr(3, -4)' + ) + ]), + data_source="form" + ) + ), + ( + '005_DataSourcePath.xlsx', + SheetParts( + name='Forms', + headings=[], + source=Apply( + Reference("api_data"), Literal("form"), + Reference('checkpoint_manager') + ), + body=None, + root_expr=Reference( + 'form.delivery_information.child_questions.[*]' + ), + data_source="form" + ) + ), + ( + '006_IncludeReferencedItems.xlsx', + SheetParts( + name='Forms', + headings=[], + source=Apply( + Reference("api_data"), Literal("form"), + Reference("checkpoint_manager"), Literal(None), + Literal(['foo', 'bar', 'bizzle']) + ), + body=None, + data_source="form" + ) + ), + ( + '010_JustDataSourceTableName.xlsx', + SheetParts( + name='my_table', + headings=[], + source=Apply( + Reference("api_data"), Literal("form"), + Reference('checkpoint_manager') + ), + body=None, + data_source="form" + ) ), ] for filename, minilinq in test_cases: - print('Compiling sheet %s' % filename) # This output will be captured by pytest and printed in case of failure; helpful to isolate which test case + # This output will be captured by pytest and printed in case + # of failure; helpful to isolate which test case + print(f'Compiling sheet {filename}') abs_path = os.path.join(os.path.dirname(__file__), filename) compiled = parse_sheet(openpyxl.load_workbook(abs_path).active) # Print will be suppressed by pytest unless it fails if not (compiled == minilinq): - print('In %s:' % filename) + print(f'In {filename}:') pprint.pprint(compiled) print('!=') pprint.pprint(minilinq) @@ -180,25 +275,54 @@ def test_parse_sheet(self): def test_parse_workbook(self): field_mappings = {'t1': 'Form 1', 't2': 'Form 2'} test_cases = [ - ('004_TwoDataSources.xlsx', - [ - SheetParts(name='Forms', headings=[], source=Apply(Reference("api_data"), Literal("form"), Reference('checkpoint_manager')), body=None), - SheetParts(name='Cases', headings=[], 
source=Apply(Reference("api_data"), Literal("case"), Reference('checkpoint_manager')), body=None) - ]), - ('007_Mappings.xlsx', - [ - SheetParts( - name='Forms', - headings=[Literal('Form Type')], - source=Apply(Reference("api_data"), Literal("form"), Reference('checkpoint_manager')), - body=List([compile_mapped_field(field_mappings, Reference("type"))]), - ) - ]), - + ( + '004_TwoDataSources.xlsx', [ + SheetParts( + name='Forms', + headings=[], + source=Apply( + Reference("api_data"), Literal("form"), + Reference('checkpoint_manager') + ), + body=None, + data_source="form" + ), + SheetParts( + name='Cases', + headings=[], + source=Apply( + Reference("api_data"), Literal("case"), + Reference('checkpoint_manager') + ), + body=None, + data_source="case" + ) + ] + ), + ( + '007_Mappings.xlsx', [ + SheetParts( + name='Forms', + headings=[Literal('Form Type')], + source=Apply( + Reference("api_data"), Literal("form"), + Reference('checkpoint_manager') + ), + body=List([ + compile_mapped_field( + field_mappings, Reference("type") + ) + ]), + data_source="form" + ) + ] + ), ] for filename, minilinq in test_cases: - print('Compiling workbook %s' % filename) # This output will be captured by pytest and printed in case of failure; helpful to isolate which test case + # This output will be captured by pytest and printed in case + # of failure; helpful to isolate which test case + print(f'Compiling workbook {filename}') abs_path = os.path.join(os.path.dirname(__file__), filename) compiled = parse_workbook(openpyxl.load_workbook(abs_path)) # Print will be suppressed by pytest unless it fails @@ -211,98 +335,179 @@ def test_parse_workbook(self): def test_compile_mapped_field(self): env = BuiltInEnv() | JsonPathEnv({'foo': {'bar': 'a', 'baz': 'b'}}) - expression = compile_mapped_field({'a': 'mapped from a'}, Reference('foo.bar')) + expression = compile_mapped_field({'a': 'mapped from a'}, + Reference('foo.bar')) assert expression.eval(env) == 'mapped from a' - expression = compile_mapped_field({'a': 'mapped from a'}, Reference('foo.baz')) + expression = compile_mapped_field({'a': 'mapped from a'}, + Reference('foo.baz')) assert list(expression.eval(env))[0].value == 'b' - expression = compile_mapped_field({'a': 'mapped from a'}, Reference('foo.boo')) + expression = compile_mapped_field({'a': 'mapped from a'}, + Reference('foo.boo')) assert list(expression.eval(env)) == [] def test_get_queries_from_excel(self): - minilinq = Bind('checkpoint_manager', Apply(Reference('get_checkpoint_manager'), Literal(["Forms"])), + minilinq = Bind( + 'checkpoint_manager', + Apply( + Reference('get_checkpoint_manager'), Literal("form"), + Literal(["Forms"]) + ), Emit( - table='Forms', - missing_value='---', - headings =[ - Literal('Form Type'), Literal('Fecha de Nacimiento'), Literal('Sexo'), - Literal('Danger 0'), Literal('Danger 1'), Literal('Danger Fever'), - Literal('Danger error'), Literal('Danger error'), Literal('special'), - Literal('Danger substring 1'), Literal('Danger substring 2'), - Literal('Danger substring error 3'), Literal('Danger substring error 4'), - Literal('Danger substring error 5') - ], - source = Map( - source=Apply(Reference("api_data"), Literal("form"), Reference('checkpoint_manager')), - body = List([ - Reference("type"), - Apply(Reference("FormatDate"), Reference("date_of_birth")), - Apply(Reference("sexo"), Reference("gender")), - Apply(Reference("selected-at"), Reference("dangers"), Literal(0)), - Apply(Reference("selected-at"), Reference("dangers"), Literal(1)), - 
Apply(Reference("selected"), Reference("dangers"), Literal('fever')), - Literal('Error: selected-at index must be an integer: selected-at(abc)'), - Literal('Error: Unable to parse: selected(fever'), - Reference('path."#text"'), - Apply(Reference("substr"), Reference("dangers"), Literal(0), Literal(10)), - Apply(Reference("substr"), Reference("dangers"), Literal(4), Literal(3)), - Literal('Error: both substr arguments must be non-negative integers: substr(a, b)'), - Literal('Error: both substr arguments must be non-negative integers: substr(-1, 10)'), - Literal('Error: both substr arguments must be non-negative integers: substr(3, -4)') - ])) + table='Forms', + missing_value='---', + headings=[ + Literal('Form Type'), + Literal('Fecha de Nacimiento'), + Literal('Sexo'), + Literal('Danger 0'), + Literal('Danger 1'), + Literal('Danger Fever'), + Literal('Danger error'), + Literal('Danger error'), + Literal('special'), + Literal('Danger substring 1'), + Literal('Danger substring 2'), + Literal('Danger substring error 3'), + Literal('Danger substring error 4'), + Literal('Danger substring error 5') + ], + source=Map( + source=Apply( + Reference("api_data"), Literal("form"), + Reference('checkpoint_manager') + ), + body=List([ + Reference("type"), + Apply( + Reference("FormatDate"), + Reference("date_of_birth") + ), + Apply(Reference("sexo"), Reference("gender")), + Apply( + Reference("selected-at"), Reference("dangers"), + Literal(0) + ), + Apply( + Reference("selected-at"), Reference("dangers"), + Literal(1) + ), + Apply( + Reference("selected"), Reference("dangers"), + Literal('fever') + ), + Literal( + 'Error: selected-at index must be an integer: ' + 'selected-at(abc)' + ), + Literal('Error: Unable to parse: selected(fever'), + Reference('path."#text"'), + Apply( + Reference("substr"), Reference("dangers"), + Literal(0), Literal(10) + ), + Apply( + Reference("substr"), Reference("dangers"), + Literal(4), Literal(3) + ), + Literal( + 'Error: both substr arguments must be ' + 'non-negative integers: substr(a, b)' + ), + Literal( + 'Error: both substr arguments must be ' + 'non-negative integers: substr(-1, 10)' + ), + Literal( + 'Error: both substr arguments must be ' + 'non-negative integers: substr(3, -4)' + ) + ]) + ) ) ) - self._compare_minilinq_to_compiled(minilinq, '003_DataSourceAndEmitColumns.xlsx') + self._compare_minilinq_to_compiled( + minilinq, '003_DataSourceAndEmitColumns.xlsx' + ) def test_alternate_source_fields(self): minilinq = List([ # First sheet uses a CSV column and also tests combining "Map Via" - Bind('checkpoint_manager', Apply(Reference('get_checkpoint_manager'), Literal(["Forms"])), + Bind( + 'checkpoint_manager', + Apply( + Reference('get_checkpoint_manager'), Literal("form"), + Literal(["Forms"]) + ), Emit( - table='Forms', missing_value='---', - headings =[ + table='Forms', + missing_value='---', + headings=[ Literal('dob'), ], - source = Map( - source=Apply(Reference("api_data"), Literal("form"), Reference('checkpoint_manager')), - body = List([ + source=Map( + source=Apply( + Reference("api_data"), Literal("form"), + Reference('checkpoint_manager') + ), + body=List([ Apply( Reference("str2date"), Apply( - Reference("or"), - Reference("dob"), Reference("date_of_birth"), Reference("d_o_b") + Reference("or"), Reference("dob"), + Reference("date_of_birth"), + Reference("d_o_b") ) ), - ])) + ]) + ) ) ), # Second sheet uses multiple alternate source field columns (listed out of order) - Bind('checkpoint_manager', Apply(Reference('get_checkpoint_manager'), 
Literal(["Forms1"])), + Bind( + 'checkpoint_manager', + Apply( + Reference('get_checkpoint_manager'), Literal("form"), + Literal(["Forms1"]) + ), Emit( - table='Forms1', missing_value='---', + table='Forms1', + missing_value='---', headings=[ - Literal('dob'), Literal('Sex'), + Literal('dob'), + Literal('Sex'), ], source=Map( - source=Apply(Reference("api_data"), Literal("form"), Reference('checkpoint_manager')), + source=Apply( + Reference("api_data"), Literal("form"), + Reference('checkpoint_manager') + ), body=List([ Reference("dob"), Apply( - Reference("or"), - Reference("gender"), Reference("sex"), Reference("sex0") + Reference("or"), Reference("gender"), + Reference("sex"), Reference("sex0") ) - ])) + ]) + ) ) ), ]) - self._compare_minilinq_to_compiled(minilinq, '011_AlternateSourceFields.xlsx') + self._compare_minilinq_to_compiled( + minilinq, '011_AlternateSourceFields.xlsx' + ) def test_columns_with_data_types(self): - minilinq = Bind('checkpoint_manager', Apply(Reference('get_checkpoint_manager'), Literal(["Forms"])), + minilinq = Bind( + 'checkpoint_manager', + Apply( + Reference('get_checkpoint_manager'), Literal("form"), + Literal(["Forms"]) + ), Emit( table='Forms', missing_value='---', @@ -314,7 +519,10 @@ def test_columns_with_data_types(self): Literal('Bad Type'), ], source=Map( - source=Apply(Reference("api_data"), Literal("form"), Reference('checkpoint_manager')), + source=Apply( + Reference("api_data"), Literal("form"), + Reference('checkpoint_manager') + ), body=List([ Reference("name"), Reference("date_of_birth"), @@ -332,22 +540,30 @@ def test_columns_with_data_types(self): ], ), ) - self._compare_minilinq_to_compiled(minilinq, '012_ColumnsWithTypes.xlsx') + self._compare_minilinq_to_compiled( + minilinq, '012_ColumnsWithTypes.xlsx' + ) def test_multi_emit(self): minilinq = List([ - Bind("checkpoint_manager", Apply(Reference('get_checkpoint_manager'), Literal(["Forms", "Cases"])), + Bind( + "checkpoint_manager", + Apply( + Reference('get_checkpoint_manager'), Literal("form"), + Literal(["Forms", "Cases"]) + ), Filter( - predicate=Apply( - Reference("filter_empty"), - Reference("$") - ), + predicate=Apply(Reference("filter_empty"), Reference("$")), source=Map( - source=Apply(Reference("api_data"), Literal("form"), Reference('checkpoint_manager')), + source=Apply( + Reference("api_data"), Literal("form"), + Reference('checkpoint_manager') + ), body=List([ Emit( table="Forms", - headings=[Literal("id"), Literal("name")], + headings=[Literal("id"), + Literal("name")], missing_value='---', source=Map( source=Reference("`this`"), @@ -374,40 +590,59 @@ def test_multi_emit(self): ), Bind( 'checkpoint_manager', - Apply(Reference('get_checkpoint_manager'), Literal(["Other cases"])), + Apply( + Reference('get_checkpoint_manager'), Literal("case"), + Literal(["Other cases"]) + ), Emit( table="Other cases", headings=[Literal("id")], missing_value='---', source=Map( - source=Apply(Reference("api_data"), Literal("case"), Reference('checkpoint_manager')), - body=List([ - Reference("id") - ]) + source=Apply( + Reference("api_data"), Literal("case"), + Reference('checkpoint_manager') + ), + body=List([Reference("id")]) ) ) ) ]) - self._compare_minilinq_to_compiled(minilinq, '008_multiple-tables.xlsx', combine=True) + self._compare_minilinq_to_compiled( + minilinq, '008_multiple-tables.xlsx', combine_emits=True + ) def test_multi_emit_no_combine(self): minilinq = List([ - Bind("checkpoint_manager", Apply(Reference('get_checkpoint_manager'), Literal(["Forms"])), - Emit( + Bind( + 
"checkpoint_manager", + Apply( + Reference('get_checkpoint_manager'), Literal("form"), + Literal(["Forms"]) + ), + Emit( table="Forms", headings=[Literal("id"), Literal("name")], missing_value='---', source=Map( - source=Apply(Reference("api_data"), Literal("form"), Reference('checkpoint_manager')), + source=Apply( + Reference("api_data"), Literal("form"), + Reference('checkpoint_manager') + ), body=List([ Reference("id"), Reference("form.name"), ]), ) - ) + ) ), - Bind("checkpoint_manager", Apply(Reference('get_checkpoint_manager'), Literal(["Cases"])), + Bind( + "checkpoint_manager", + Apply( + Reference('get_checkpoint_manager'), Literal("form"), + Literal(["Cases"]) + ), Emit( table="Cases", headings=[Literal("case_id")], @@ -415,7 +650,10 @@ def test_multi_emit_no_combine(self): source=Map( source=FlatMap( body=Reference("form..case"), - source=Apply(Reference("api_data"), Literal("form"), Reference('checkpoint_manager')) + source=Apply( + Reference("api_data"), Literal("form"), + Reference('checkpoint_manager') + ) ), body=List([ Reference("@case_id"), @@ -423,37 +661,54 @@ def test_multi_emit_no_combine(self): ) ) ), - Bind("checkpoint_manager", Apply(Reference('get_checkpoint_manager'), Literal(["Other cases"])), + Bind( + "checkpoint_manager", + Apply( + Reference('get_checkpoint_manager'), Literal("case"), + Literal(["Other cases"]) + ), Emit( table="Other cases", headings=[Literal("id")], missing_value='---', source=Map( - source=Apply(Reference("api_data"), Literal("case"), Reference('checkpoint_manager')), - body=List([ - Reference("id") - ]) + source=Apply( + Reference("api_data"), Literal("case"), + Reference('checkpoint_manager') + ), + body=List([Reference("id")]) ) ) ) ]) - self._compare_minilinq_to_compiled(minilinq, '008_multiple-tables.xlsx', combine=False) + self._compare_minilinq_to_compiled( + minilinq, '008_multiple-tables.xlsx', combine_emits=False + ) def test_multi_emit_with_organization(self): minilinq = List([ - Bind("checkpoint_manager", Apply(Reference('get_checkpoint_manager'), Literal(["Forms", "Cases"])), + Bind( + "checkpoint_manager", + Apply( + Reference('get_checkpoint_manager'), Literal("form"), + Literal(["Forms", "Cases"]) + ), Filter( - predicate=Apply( - Reference("filter_empty"), - Reference("$") - ), + predicate=Apply(Reference("filter_empty"), Reference("$")), source=Map( - source=Apply(Reference("api_data"), Literal("form"), Reference('checkpoint_manager')), + source=Apply( + Reference("api_data"), Literal("form"), + Reference('checkpoint_manager') + ), body=List([ Emit( table="Forms", - headings=[Literal("id"), Literal("name"), Literal("commcare_userid")], + headings=[ + Literal("id"), + Literal("name"), + Literal("commcare_userid") + ], missing_value='---', source=Map( source=Reference("`this`"), @@ -466,7 +721,10 @@ def test_multi_emit_with_organization(self): ), Emit( table="Cases", - headings=[Literal("case_id"), Literal("commcare_userid")], + headings=[ + Literal("case_id"), + Literal("commcare_userid") + ], missing_value='---', source=Map( source=Reference("form..case"), @@ -482,28 +740,121 @@ def test_multi_emit_with_organization(self): ), Bind( 'checkpoint_manager', - Apply(Reference('get_checkpoint_manager'), Literal(["Other cases"])), + Apply( + Reference('get_checkpoint_manager'), Literal("case"), + Literal(["Other cases"]) + ), Emit( table="Other cases", - headings=[Literal("id"), Literal("commcare_userid")], + headings=[Literal("id"), + Literal("commcare_userid")], + missing_value='---', + source=Map( + source=Apply( + 
Reference("api_data"), Literal("case"), + Reference('checkpoint_manager') + ), + body=List([Reference("id"), + Reference("$.user_id")]) + ) + ) + ) + ]) + + column_enforcer = ColumnEnforcer() + self._compare_minilinq_to_compiled( + minilinq, + '008_multiple-tables.xlsx', + combine_emits=True, + column_enforcer=column_enforcer + ) + + def test_value_or_root(self): + minilinq = List([ + Bind( + "checkpoint_manager", + Apply( + Reference('get_checkpoint_manager'), Literal("form"), + Literal(["Forms"]) + ), + Emit( + table="Forms", + headings=[Literal("id"), Literal("name")], missing_value='---', source=Map( - source=Apply(Reference("api_data"), Literal("case"), Reference('checkpoint_manager')), + source=Apply( + Reference("api_data"), Literal("form"), + Reference('checkpoint_manager') + ), body=List([ Reference("id"), - Reference("$.user_id") - ]) + Reference("form.name"), + ]), + ) + ) + ), + Bind( + "checkpoint_manager", + Apply( + Reference('get_checkpoint_manager'), Literal("form"), + Literal(["Cases"]) + ), + Emit( + table="Cases", + headings=[Literal("case_id")], + missing_value='---', + source=Map( + source=FlatMap( + body=Apply( + Reference("_or_raw"), Reference("form..case"), + Bind( + "__root_only", Literal(True), + Reference("$") + ) + ), + source=Apply( + Reference("api_data"), Literal("form"), + Reference('checkpoint_manager') + ) + ), + body=List([ + Reference("@case_id"), + ]), + ) + ) + ), + Bind( + "checkpoint_manager", + Apply( + Reference('get_checkpoint_manager'), Literal("case"), + Literal(["Other cases"]) + ), + Emit( + table="Other cases", + headings=[Literal("id")], + missing_value='---', + source=Map( + source=Apply( + Reference("api_data"), Literal("case"), + Reference('checkpoint_manager') + ), + body=List([Reference("id")]) ) ) ) ]) - column_enforcer = ColumnEnforcer() - self._compare_minilinq_to_compiled(minilinq, '008_multiple-tables.xlsx', combine=True, - column_enforcer=column_enforcer) + self._compare_minilinq_to_compiled( + minilinq, + '008_multiple-tables.xlsx', + combine_emits=False, + value_or_root=True + ) - def _compare_minilinq_to_compiled(self, minilinq, filename, combine=False, column_enforcer=None): + def _compare_minilinq_to_compiled(self, minilinq, filename, **kwargs): print("Parsing {}".format(filename)) abs_path = os.path.join(os.path.dirname(__file__), filename) - compiled = get_queries_from_excel(openpyxl.load_workbook(abs_path), missing_value='---', combine_emits=combine, column_enforcer=column_enforcer) + compiled = get_queries_from_excel( + openpyxl.load_workbook(abs_path), missing_value='---', **kwargs + ) assert compiled.to_jvalue() == minilinq.to_jvalue(), filename diff --git a/tests/test_map_format.py b/tests/test_map_format.py index c91a541f..d662aec8 100644 --- a/tests/test_map_format.py +++ b/tests/test_map_format.py @@ -1,30 +1,53 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes - import unittest -from commcare_export.map_format import ( - parse_template, - parse_function_arg, -) -from commcare_export.minilinq import Apply, Reference, Literal +from commcare_export.map_format import parse_function_arg, parse_template +from commcare_export.minilinq import Apply, Literal, Reference class TestMapFormats(unittest.TestCase): + def test_parse_template_no_args(self): - expected = Apply(Reference('template'), Literal('my name is {}'), Reference('form.question1')) - assert parse_template(Reference('form.question1'), 'template(my name is {})') == expected 
+ expected = Apply( + Reference('template'), Literal('my name is {}'), + Reference('form.question1') + ) + assert parse_template( + Reference('form.question1'), 'template(my name is {})' + ) == expected def test_parse_template_args(self): - expected = Apply(Reference('template'), Literal('my name is {}'), Reference('form.question2')) - assert parse_template('form.question1', 'template(my name is {}, form.question2)') == expected + expected = Apply( + Reference('template'), Literal('my name is {}'), + Reference('form.question2') + ) + assert parse_template( + 'form.question1', 'template(my name is {}, form.question2)' + ) == expected + + def test_parse_template_args_long(self): + expected = Apply( + Reference('template'), + Literal('https://www.commcarehq.org/a/{}/reports/form_data/{}/'), + Reference('$.domain'), + Reference('$.id'), + ) + assert parse_template( + 'form.id', + + 'template(https://www.commcarehq.org/a/{}/reports/form_data/{}/, ' + '$.domain, $.id)' + ) == expected def test_parse_template_no_template(self): - expected = Literal('Error: template function requires the format template: template()') + expected = Literal( + 'Error: template function requires the format template: template()' + ) assert parse_template('form.question1', 'template()') == expected def test_parse_function_arg_with_brackets(self): - value_returned = parse_function_arg('selected', 'selected(Other_(Specify))') + value_returned = parse_function_arg( + 'selected', 'selected(Other_(Specify))' + ) assert value_returned == 'Other_(Specify)' def test_parse_function_arg_empty_returns(self): diff --git a/tests/test_minilinq.py b/tests/test_minilinq.py index 6d257c13..b814d145 100644 --- a/tests/test_minilinq.py +++ b/tests/test_minilinq.py @@ -1,20 +1,24 @@ -# -*- coding: utf-8 -*- +import types import unittest +from datetime import datetime from itertools import * import pytest -from six.moves import map, xrange - -from jsonpath_rw import jsonpath - -from commcare_export.minilinq import * -from commcare_export.repeatable_iterator import RepeatableIterator from commcare_export.env import * +from commcare_export.excel_query import get_value_or_root_expression +from commcare_export.minilinq import * from commcare_export.writers import JValueTableWriter -class LazinessException(Exception): pass -def die(msg): raise LazinessException(msg) # Hack: since "raise" is a statement not an expression, need a funcall wrapping it +class LazinessException(Exception): + pass + + +def die(msg): + # Hack: since "raise" is a statement not an expression, need a + # funcall wrapping it + raise LazinessException(msg) + class TestMiniLinq(unittest.TestCase): @@ -24,7 +28,9 @@ def setup_class(cls): def check_case(self, val, expected): if isinstance(expected, list): - assert [datum.value if isinstance(datum, jsonpath.DatumInContext) else datum for datum in val] == expected + assert [unwrap_val(datum) for datum in val] == expected + else: + assert val == expected def test_eval_literal(self): env = BuiltInEnv() @@ -35,30 +41,200 @@ def test_eval_literal(self): def test_eval_reference(self): env = BuiltInEnv() assert Reference("foo").eval(DictEnv({'foo': 2})) == 2 - assert Reference(Reference(Reference('a'))).eval(DictEnv({'a': 'b', 'b': 'c', 'c': 2})) == 2 - self.check_case(Reference("foo[*]").eval(JsonPathEnv({'foo': [2]})), [2]) - self.check_case(Reference("foo[*]").eval(JsonPathEnv({'foo': xrange(0, 1)})), [0]) # Should work the same w/ iterators as with lists - - # Should be able to get back out to the root, as the JsonPathEnv actually 
passes the full datum around - self.check_case(Reference("foo.$.baz").eval(JsonPathEnv({'foo': [2], 'baz': 3})), [3]) + assert Reference(Reference(Reference('a')) + ).eval(DictEnv({ + 'a': 'b', + 'b': 'c', + 'c': 2 + })) == 2 + self.check_case( + Reference("foo[*]").eval(JsonPathEnv({'foo': [2]})), [2] + ) + # Should work the same w/ iterators as with lists + self.check_case( + Reference("foo[*]").eval(JsonPathEnv({'foo': range(0, 1)})), [0] + ) + + # Should be able to get back out to the root, as the JsonPathEnv + # actually passes the full datum around + self.check_case( + Reference("foo.$.baz").eval(JsonPathEnv({ + 'foo': [2], + 'baz': 3 + })), [3] + ) def test_eval_auto_id_reference(self): - "Test that we have turned on the jsonpath_rw.jsonpath.auto_id field properly" + """ + Test that we have turned on the jsonpath_ng.jsonpath.auto_id + field properly + """ env = BuiltInEnv() - self.check_case(Reference("foo.id").eval(JsonPathEnv({'foo': [2]})), ['foo']) + self.check_case( + Reference("foo.id").eval(JsonPathEnv({'foo': [2]})), ['foo'] + ) # When auto id is on, this always becomes a string. Sorry! - self.check_case(Reference("foo.id").eval(JsonPathEnv({'foo': {'id': 2}})), ['2']) + self.check_case( + Reference("foo.id").eval(JsonPathEnv({'foo': { + 'id': 2 + }})), ['2'] + ) + + def test_eval_auto_id_reference_nested(self): + # this test is documentation of existing (weird) functionality + # that results from a combination of jsonpath_ng auto_id feature + # and JsonPathEnv.lookup (which adds an additional auto ID for + # some reason). + env = JsonPathEnv({}) + + flatmap = FlatMap( + source=Literal([{ + "id": 1, + "foo": { + 'id': 'bid', + 'name': 'bob' + }, + "bar": [{ + 'baz': 'a1' + }, { + 'baz': 'a2', + 'id': 'bazzer' + }] + }]), + body=Reference('bar.[*]') + ) + mmap = Map( + source=flatmap, + body=List([ + Reference("id"), + Reference('baz'), + Reference('$.id'), + Reference('$.foo.id'), + Reference('$.foo.name') + ]) + ) + self.check_case( + mmap.eval(env), [["1.bar.'1.bar.[0]'", 'a1', '1', '1.bid', 'bob'], + ['1.bar.bazzer', 'a2', '1', '1.bid', 'bob']] + ) + + # Without the additional auto id field added in JsonPathEnv the + # result for Reference("id") changes as follows: + # '1.bar.1.bar.[0]' -> '1.bar.[0]' + + # With the change above AND a change to jsonpath_ng to prevent + # converting IDs that exist into auto IDs (see + # https://github.com/kennknowles/python-jsonpath-rw/pull/96) we + # get the following: + # Reference("id"): + # '1.bar.bazzer' -> 'bazzer' + # + # Reference('$.foo.id'): + # '1.bid' -> 'bid' + + def test_value_or_root(self): + """ + Test that when accessing a child object the child data is used + if it exists (normal case). 
+ """ + data = {"id": 1, "bar": [{'baz': 'a1'}, {'baz': 'a2'}]} + self._test_value_or_root([Reference('id'), + Reference('baz')], data, [ + ["1.bar.'1.bar.[0]'", 'a1'], + ["1.bar.'1.bar.[1]'", 'a2'], + ]) + + def test_value_or_root_empty_list(self): + """Should use the root object if the child is an empty list""" + data = { + "id": 1, + "foo": "I am foo", + "bar": [], + } + self._test_value_or_root([ + Reference('id'), + Reference('baz'), + Reference('$.foo') + ], data, [ + ['1', [], "I am foo"], + ]) + + def test_value_or_root_empty_dict(self): + """Should use the root object if the child is an empty dict""" + data = { + "id": 1, + "foo": "I am foo", + "bar": {}, + } + self._test_value_or_root([ + Reference('id'), + Reference('baz'), + Reference('$.foo') + ], data, [ + ['1', [], "I am foo"], + ]) + + def test_value_or_root_None(self): + """Should use the root object if the child is None""" + data = { + "id": 1, + "bar": None, + } + self._test_value_or_root([Reference('id'), + Reference('baz')], data, [ + ['1', []], + ]) + + def test_value_or_root_missing(self): + """Should use the root object if the child does not exist""" + data = { + "id": 1, + "foo": "I am foo", + # 'bar' is missing + } + self._test_value_or_root([ + Reference('id'), + Reference('baz'), + Reference('$.foo') + ], data, [ + ['1', [], 'I am foo'], + ]) + + def test_value_or_root_ignore_field_in_root(self): + """ + Test that a child reference is ignored if we are using the root + doc even if there is a field with that name. (this doesn't apply + to 'id') + """ + data = { + "id": 1, + "foo": "I am foo", + } + self._test_value_or_root([Reference('id'), + Reference('foo')], data, [ + ['1', []], + ]) + + def _test_value_or_root(self, columns, data, expected): + """Low level test case for 'value-or-root'""" + env = BuiltInEnv() | JsonPathEnv({}) + value_or_root = get_value_or_root_expression('bar.[*]') + flatmap = FlatMap(source=Literal([data]), body=value_or_root) + mmap = Map(source=flatmap, body=List(columns)) + self.check_case(mmap.eval(env), expected) def test_eval_collapsed_list(self): """ - Special case to handle XML -> JSON conversion where there just happened to be a single value at save time + Special case to handle XML -> JSON conversion where there just + happened to be a single value at save time """ env = BuiltInEnv() self.check_case(Reference("foo[*]").eval(JsonPathEnv({'foo': 2})), [2]) assert Apply(Reference("*"), Literal(2), Literal(3)).eval(env) == 6 - assert Apply(Reference(">"), Literal(56), Literal(23.5)).eval(env) == True + assert Apply(Reference(">"), Literal(56), + Literal(23.5)).eval(env) == True assert Apply(Reference("len"), Literal([1, 2, 3])).eval(env) == 3 assert Apply(Reference("bool"), Literal('a')).eval(env) == True assert Apply(Reference("bool"), Literal('')).eval(env) == False @@ -66,130 +242,315 @@ def test_eval_collapsed_list(self): assert Apply(Reference("str2bool"), Literal('t')).eval(env) == True assert Apply(Reference("str2bool"), Literal('1')).eval(env) == True assert Apply(Reference("str2bool"), Literal('0')).eval(env) == False - assert Apply(Reference("str2bool"), Literal('false')).eval(env) == False + assert Apply(Reference("str2bool"), + Literal('false')).eval(env) == False assert Apply(Reference("str2bool"), Literal(u'日本')).eval(env) == False assert Apply(Reference("str2num"), Literal('10')).eval(env) == 10 assert Apply(Reference("str2num"), Literal('10.56')).eval(env) == 10.56 assert Apply(Reference("str2num"), Literal('')).eval(env) == None - assert Apply(Reference("str2date"), 
Literal('2015-01-01')).eval(env) == datetime(2015, 1, 1) - assert Apply(Reference("str2date"), Literal('2015-01-01T18:32:57')).eval(env) == datetime(2015, 1, 1, 18, 32, 57) - assert Apply(Reference("str2date"), Literal('2015-01-01T18:32:57.001200')).eval(env) == datetime(2015, 1, 1, 18, 32, 57) - assert Apply(Reference("str2date"), Literal('2015-01-01T18:32:57.001200Z')).eval(env) == datetime(2015, 1, 1, 18, 32, 57) - assert Apply(Reference("str2date"), Literal(u'日'.encode('utf8'))).eval(env) == None + assert Apply(Reference("str2date"), + Literal('2015-01-01')).eval(env) == datetime(2015, 1, 1) + assert Apply(Reference("str2date"), Literal('2015-01-01T18:32:57') + ).eval(env) == datetime(2015, 1, 1, 18, 32, 57) + assert Apply( + Reference("str2date"), Literal('2015-01-01T18:32:57.001200') + ).eval(env) == datetime(2015, 1, 1, 18, 32, 57) + assert Apply( + Reference("str2date"), Literal('2015-01-01T18:32:57.001200Z') + ).eval(env) == datetime(2015, 1, 1, 18, 32, 57) + assert Apply(Reference("str2date"), + Literal(u'日'.encode('utf8'))).eval(env) == None assert Apply(Reference("str2date"), Literal(u'日')).eval(env) == None - assert Apply(Reference("selected-at"), Literal('a b c'), Literal('1')).eval(env) == 'b' - assert Apply(Reference("selected-at"), Literal(u'a b 日'), Literal('-1')).eval(env) == u'日' - assert Apply(Reference("selected-at"), Literal('a b c'), Literal('5')).eval(env) is None - assert Apply(Reference("selected"), Literal('a b c'), Literal('b')).eval(env) is True - assert Apply(Reference("selected"), Literal(u'a b 日本'), Literal('d')).eval(env) is False - assert Apply(Reference("selected"), Literal(u'a bb 日本'), Literal('b')).eval(env) is False - assert Apply(Reference("selected"), Literal(u'a bb 日本'), Literal(u'日本')).eval(env) is True - assert Apply(Reference("join"), Literal('.'), Literal('a'), Literal('b'), Literal('c')).eval(env) == 'a.b.c' - assert Apply(Reference("default"), Literal(None), Literal('a')).eval(env) == 'a' - assert Apply(Reference("default"), Literal('b'), Literal('a')).eval(env) == 'b' - assert Apply(Reference("count-selected"), Literal(u'a bb 日本')).eval(env) == 3 - assert Apply(Reference("sha1"), Literal(u'a bb 日本')).eval(env) == 'e25a54025417b06d88d40baa8c71f6eee9c07fb1' - assert Apply(Reference("sha1"), Literal(b'2015')).eval(env) == '9cdda67ded3f25811728276cefa76b80913b4c54' - assert Apply(Reference("sha1"), Literal(2015)).eval(env) == '9cdda67ded3f25811728276cefa76b80913b4c54' + assert Apply(Reference("format-uuid"), + Literal(0xf00)).eval(env) == None + assert Apply(Reference("format-uuid"), + Literal('f00')).eval(env) == None + assert Apply( + Reference("format-uuid"), + Literal('00a3e019-4ce1-4587-94c5-0971dee2de22') + ).eval(env) == '00a3e019-4ce1-4587-94c5-0971dee2de22' + assert Apply(Reference("selected-at"), Literal('a b c'), + Literal('1')).eval(env) == 'b' + assert Apply( + Reference("selected-at"), Literal(u'a b 日'), Literal('-1') + ).eval(env) == u'日' + assert Apply(Reference("selected-at"), Literal('a b c'), + Literal('5')).eval(env) is None + assert Apply(Reference("selected"), Literal('a b c'), + Literal('b')).eval(env) is True + assert Apply(Reference("selected"), Literal(u'a b 日本'), + Literal('d')).eval(env) is False + assert Apply(Reference("selected"), Literal(u'a bb 日本'), + Literal('b')).eval(env) is False + assert Apply( + Reference("selected"), Literal(u'a bb 日本'), Literal(u'日本') + ).eval(env) is True + assert Apply( + Reference("join"), Literal('.'), Literal('a'), Literal('b'), + Literal('c') + ).eval(env) == 'a.b.c' + assert 
Apply(Reference("default"), Literal(None), + Literal('a')).eval(env) == 'a' + assert Apply(Reference("default"), Literal('b'), + Literal('a')).eval(env) == 'b' + assert Apply(Reference("count-selected"), + Literal(u'a bb 日本')).eval(env) == 3 + assert Apply(Reference("sha1"), Literal(u'a bb 日本') + ).eval(env) == 'e25a54025417b06d88d40baa8c71f6eee9c07fb1' + assert Apply(Reference("sha1"), Literal(b'2015') + ).eval(env) == '9cdda67ded3f25811728276cefa76b80913b4c54' + assert Apply(Reference("sha1"), Literal(2015) + ).eval(env) == '9cdda67ded3f25811728276cefa76b80913b4c54' def test_or(self): env = BuiltInEnv() assert Apply(Reference("or"), Literal(None), Literal(2)).eval(env) == 2 - laziness_iterator = RepeatableIterator(lambda: (i if i < 1 else die('Not lazy enough') for i in range(2))) - assert Apply(Reference("or"), Literal(1), Literal(laziness_iterator)).eval(env) == 1 - assert Apply(Reference("or"), Literal(''), Literal(laziness_iterator)).eval(env) == '' - assert Apply(Reference("or"), Literal(0), Literal(laziness_iterator)).eval(env) == 0 + laziness_iterator = RepeatableIterator( + lambda: (i if i < 1 else die('Not lazy enough') for i in range(2)) + ) + assert Apply(Reference("or"), Literal(1), + Literal(laziness_iterator)).eval(env) == 1 + assert Apply(Reference("or"), Literal(''), + Literal(laziness_iterator)).eval(env) == '' + assert Apply(Reference("or"), Literal(0), + Literal(laziness_iterator)).eval(env) == 0 with pytest.raises(LazinessException): - Apply(Reference("or"), Literal(None), Literal(laziness_iterator)).eval(env) + Apply(Reference("or"), Literal(None), + Literal(laziness_iterator)).eval(env) env = env | JsonPathEnv({'a': {'c': 'c val'}}) - assert Apply(Reference("or"), Reference('a.b'), Reference('a.c')).eval(env) == 'c val' - assert Apply(Reference("or"), Reference('a.b'), Reference('a.d')).eval(env) is None + assert Apply(Reference("or"), Reference('a.b'), + Reference('a.c')).eval(env) == 'c val' + assert Apply(Reference("or"), Reference('a.b'), + Reference('a.d')).eval(env) is None + + env = env.replace({'a': [], 'b': [1, 2], 'c': 2}) + self.check_case( + Apply(Reference("or"), Reference('a.[*]'), + Reference('b')).eval(env), [1, 2] + ) + self.check_case( + Apply(Reference("or"), Reference('b.[*]'), + Reference('c')).eval(env), [1, 2] + ) + self.check_case( + Apply(Reference("or"), Reference('a.[*]'), + Reference('$')).eval(env), { + 'a': [], + 'b': [1, 2], + 'c': 2, + 'id': '$' + } + ) def test_attachment_url(self): - env = BuiltInEnv({'commcarehq_base_url': 'https://www.commcarehq.org'}) | JsonPathEnv({'id': '123', 'domain': 'd1', 'photo': 'a.jpg'}) + env = BuiltInEnv({'commcarehq_base_url': 'https://www.commcarehq.org'} + ) | JsonPathEnv({ + 'id': '123', + 'domain': 'd1', + 'photo': 'a.jpg' + }) expected = 'https://www.commcarehq.org/a/d1/api/form/attachment/123/a.jpg' - assert Apply(Reference('attachment_url'), Reference('photo')).eval(env) == expected + assert Apply(Reference('attachment_url'), + Reference('photo')).eval(env) == expected def test_attachment_url_repeat(self): - env = BuiltInEnv({'commcarehq_base_url': 'https://www.commcarehq.org'}) | JsonPathEnv({ - 'id': '123', 'domain': 'd1', 'repeat': [ - {'photo': 'a.jpg'}, {'photo': 'b.jpg'} - ] - }) + env = BuiltInEnv({'commcarehq_base_url': 'https://www.commcarehq.org'} + ) | JsonPathEnv({ + 'id': '123', + 'domain': 'd1', + 'repeat': [{ + 'photo': 'a.jpg' + }, { + 'photo': 'b.jpg' + }] + }) expected = [ 'https://www.commcarehq.org/a/d1/api/form/attachment/123/a.jpg', 
'https://www.commcarehq.org/a/d1/api/form/attachment/123/b.jpg', ] - result = unwrap_val(Map( - source=Reference('repeat.[*]'), - body=Apply(Reference('attachment_url'), Reference('photo')) - ).eval(env)) + result = unwrap_val( + Map( + source=Reference('repeat.[*]'), + body=Apply(Reference('attachment_url'), Reference('photo')) + ).eval(env) + ) assert result == expected + def test_form_url(self): + env = BuiltInEnv({'commcarehq_base_url': 'https://www.commcarehq.org'} + ) | JsonPathEnv({ + 'id': '123', + 'domain': 'd1' + }) + expected = 'https://www.commcarehq.org/a/d1/reports/form_data/123/' + assert Apply(Reference('form_url'), + Reference('id')).eval(env) == expected + + def test_case_url(self): + env = BuiltInEnv({'commcarehq_base_url': 'https://www.commcarehq.org'} + ) | JsonPathEnv({ + 'id': '123', + 'domain': 'd1' + }) + expected = 'https://www.commcarehq.org/a/d1/reports/case_data/123/' + assert Apply(Reference('case_url'), + Reference('id')).eval(env) == expected + + def test_unique(self): + env = BuiltInEnv() | JsonPathEnv({ + "list": [{ + "a": 1 + }, { + "a": 2 + }, { + "a": 3 + }, { + "a": 2 + }] + }) + assert Apply(Reference('unique'), + Reference('list[*].a')).eval(env) == [1, 2, 3] + def test_template(self): env = BuiltInEnv() | JsonPathEnv({'a': '1', 'b': '2'}) - assert Apply(Reference('template'), Literal('{}.{}'), Reference('a'), Reference('b')).eval(env) == '1.2' + assert Apply( + Reference('template'), Literal('{}.{}'), Reference('a'), + Reference('b') + ).eval(env) == '1.2' def test_substr(self): - env = BuiltInEnv({'single_byte_chars': u'abcdefghijklmnopqrstuvwxyz', - 'multi_byte_chars': u'αβγδεζηθικλμνξοπρςστυφχψω', - 'an_integer': 123456 + env = BuiltInEnv({ + 'single_byte_chars': u'abcdefghijklmnopqrstuvwxyz', + 'multi_byte_chars': u'αβγδεζηθικλμνξοπρςστυφχψω', + 'an_integer': 123456 }) - assert Apply(Reference('substr'), Reference('single_byte_chars'), - Literal(-4), Literal(30)).eval(env) == None - assert Apply(Reference('substr'), Reference('single_byte_chars'), - Literal(0), Literal(26)).eval(env) == u'abcdefghijklmnopqrstuvwxyz' - assert Apply(Reference('substr'), Reference('single_byte_chars'), - Literal(10), Literal(16)).eval(env) == u'klmnop' - assert Apply(Reference('substr'), Reference('single_byte_chars'), - Literal(13), Literal(14)).eval(env) == u'n' - assert Apply(Reference('substr'), Reference('single_byte_chars'), - Literal(13), Literal(13)).eval(env) == u'' - assert Apply(Reference('substr'), Reference('single_byte_chars'), - Literal(14), Literal(13)).eval(env) == u'' - assert Apply(Reference('substr'), Reference('single_byte_chars'), - Literal(5), Literal(-1)).eval(env) == None - - assert Apply(Reference('substr'), Reference('multi_byte_chars'), - Literal(-4), Literal(30)).eval(env) == None - assert Apply(Reference('substr'), Reference('multi_byte_chars'), - Literal(0), Literal(25)).eval(env) == u'αβγδεζηθικλμνξοπρςστυφχψω' - assert Apply(Reference('substr'), Reference('multi_byte_chars'), - Literal(10), Literal(15)).eval(env) == u'λμνξο' - assert Apply(Reference('substr'), Reference('multi_byte_chars'), - Literal(13), Literal(14)).eval(env) == u'ξ' - assert Apply(Reference('substr'), Reference('multi_byte_chars'), - Literal(13), Literal(12)).eval(env) == u'' - assert Apply(Reference('substr'), Reference('multi_byte_chars'), - Literal(14), Literal(13)).eval(env) == u'' - assert Apply(Reference('substr'), Reference('multi_byte_chars'), - Literal(5), Literal(-1)).eval(env) == None - - assert Apply(Reference('substr'), Reference('an_integer'), - 
Literal(-1), Literal(3)).eval(env) == None - assert Apply(Reference('substr'), Reference('an_integer'), - Literal(0), Literal(6)).eval(env) == u'123456' - assert Apply(Reference('substr'), Reference('an_integer'), - Literal(2), Literal(4)).eval(env) == u'34' - assert Apply(Reference('substr'), Reference('an_integer'), - Literal(4), Literal(2)).eval(env) == u'' - assert Apply(Reference('substr'), Reference('an_integer'), - Literal(5), Literal(-1)).eval(env) == None + assert Apply( + Reference('substr'), Reference('single_byte_chars'), Literal(-4), + Literal(30) + ).eval(env) == None + assert Apply( + Reference('substr'), Reference('single_byte_chars'), Literal(0), + Literal(26) + ).eval(env) == u'abcdefghijklmnopqrstuvwxyz' + assert Apply( + Reference('substr'), Reference('single_byte_chars'), Literal(10), + Literal(16) + ).eval(env) == u'klmnop' + assert Apply( + Reference('substr'), Reference('single_byte_chars'), Literal(13), + Literal(14) + ).eval(env) == u'n' + assert Apply( + Reference('substr'), Reference('single_byte_chars'), Literal(13), + Literal(13) + ).eval(env) == u'' + assert Apply( + Reference('substr'), Reference('single_byte_chars'), Literal(14), + Literal(13) + ).eval(env) == u'' + assert Apply( + Reference('substr'), Reference('single_byte_chars'), Literal(5), + Literal(-1) + ).eval(env) == None + + assert Apply( + Reference('substr'), Reference('multi_byte_chars'), Literal(-4), + Literal(30) + ).eval(env) == None + assert Apply( + Reference('substr'), Reference('multi_byte_chars'), Literal(0), + Literal(25) + ).eval(env) == u'αβγδεζηθικλμνξοπρςστυφχψω' + assert Apply( + Reference('substr'), Reference('multi_byte_chars'), Literal(10), + Literal(15) + ).eval(env) == u'λμνξο' + assert Apply( + Reference('substr'), Reference('multi_byte_chars'), Literal(13), + Literal(14) + ).eval(env) == u'ξ' + assert Apply( + Reference('substr'), Reference('multi_byte_chars'), Literal(13), + Literal(12) + ).eval(env) == u'' + assert Apply( + Reference('substr'), Reference('multi_byte_chars'), Literal(14), + Literal(13) + ).eval(env) == u'' + assert Apply( + Reference('substr'), Reference('multi_byte_chars'), Literal(5), + Literal(-1) + ).eval(env) == None + + assert Apply( + Reference('substr'), Reference('an_integer'), Literal(-1), + Literal(3) + ).eval(env) == None + assert Apply( + Reference('substr'), Reference('an_integer'), Literal(0), + Literal(6) + ).eval(env) == u'123456' + assert Apply( + Reference('substr'), Reference('an_integer'), Literal(2), + Literal(4) + ).eval(env) == u'34' + assert Apply( + Reference('substr'), Reference('an_integer'), Literal(4), + Literal(2) + ).eval(env) == u'' + assert Apply( + Reference('substr'), Reference('an_integer'), Literal(5), + Literal(-1) + ).eval(env) == None def test_map(self): env = BuiltInEnv() | DictEnv({}) - laziness_iterator = RepeatableIterator(lambda: ({'a':i} if i < 5 else die('Not lazy enough') for i in range(12))) - - assert list(Map(source=Literal([{'a':1}, {'a':2}, {'a':3}]), body=Literal(1)).eval(env)) == [1, 1, 1] - assert list(Map(source=Literal([{'a':1}, {'a':2}, {'a':3}]), body=Reference('a')).eval(env)) == [1, 2, 3] - - assert list(islice(Map(source=Literal(laziness_iterator), body=Reference('a')).eval(env), 5)) == [0, 1, 2, 3, 4] + laziness_iterator = RepeatableIterator( + lambda: ({ + 'a': i + } if i < 5 else die('Not lazy enough') for i in range(12)) + ) + + assert list( + Map( + source=Literal([{ + 'a': 1 + }, { + 'a': 2 + }, { + 'a': 3 + }]), + body=Literal(1) + ).eval(env) + ) == [1, 1, 1] + assert list( + Map( + 
source=Literal([{ + 'a': 1 + }, { + 'a': 2 + }, { + 'a': 3 + }]), + body=Reference('a') + ).eval(env) + ) == [1, 2, 3] + + assert list( + islice( + Map(source=Literal(laziness_iterator), + body=Reference('a')).eval(env), 5 + ) + ) == [0, 1, 2, 3, 4] try: - list(Map(source=Literal(laziness_iterator), body=Reference('a')).eval(env)) + list( + Map(source=Literal(laziness_iterator), + body=Reference('a')).eval(env) + ) raise Exception('Should have failed') except LazinessException: pass @@ -197,61 +558,144 @@ def test_map(self): def test_flatmap(self): env = BuiltInEnv() | DictEnv({}) - laziness_iterator = RepeatableIterator(lambda: ({'a':range(i)} if i < 4 else die('Not lazy enough') for i in range(12))) - - assert list(FlatMap(source=Literal([{'a':[1]}, {'a':'foo'}, {'a':[3, 4]}]), body=Literal([1, 2, 3])).eval(env)) == [1, 2, 3, 1, 2, 3, 1, 2, 3] - assert list(FlatMap(source=Literal([{'a':[1]}, {'a':[2]}, {'a':[3, 4]}]), body=Reference('a')).eval(env)) == [1, 2, 3, 4] - - assert list(islice(FlatMap(source=Literal(laziness_iterator), body=Reference('a')).eval(env), 6)) == [0, - 0, 1, - 0, 1, 2] + laziness_iterator = RepeatableIterator( + lambda: ({ + 'a': range(i) + } if i < 4 else die('Not lazy enough') for i in range(12)) + ) + + assert list( + FlatMap( + source=Literal([{ + 'a': [1] + }, { + 'a': 'foo' + }, { + 'a': [3, 4] + }]), + body=Literal([1, 2, 3]) + ).eval(env) + ) == [1, 2, 3, 1, 2, 3, 1, 2, 3] + assert list( + FlatMap( + source=Literal([{ + 'a': [1] + }, { + 'a': [2] + }, { + 'a': [3, 4] + }]), + body=Reference('a') + ).eval(env) + ) == [1, 2, 3, 4] + + assert list( + islice( + FlatMap( + source=Literal(laziness_iterator), body=Reference('a') + ).eval(env), 6 + ) + ) == [0, 0, 1, 0, 1, 2] try: - list(FlatMap(source=Literal(laziness_iterator), body=Reference('a')).eval(env)) + list( + FlatMap( + source=Literal(laziness_iterator), body=Reference('a') + ).eval(env) + ) raise Exception('Should have failed') except LazinessException: pass + def _setup_emit_test(self, emitter_env): + env = BuiltInEnv() | JsonPathEnv({ + 'foo': { + 'baz': 3, + 'bar': True, + 'boo': None + } + }) | emitter_env + Emit( + table='Foo', + headings=[Literal('foo')], + source=List([ + List([ + Reference('foo.baz'), + Reference('foo.bar'), + Reference('foo.foo'), + Reference('foo.boo') + ]) + ]), + missing_value='---' + ).eval(env) + def test_emit(self): writer = JValueTableWriter() - env = BuiltInEnv() | JsonPathEnv({'foo': {'baz': 3, 'bar': True, 'boo': None}}) | EmitterEnv(writer) - Emit(table='Foo', - headings=[Literal('foo')], - source=List([ - List([ Reference('foo.baz'), Reference('foo.bar'), Reference('foo.foo'), Reference('foo.boo') ]) - ]), - missing_value='---').eval(env) - + self._setup_emit_test(EmitterEnv(writer)) assert list(writer.tables['Foo'].rows) == [[3, True, '---', None]] + def test_emit_generator(self): + + class TestWriter(JValueTableWriter): + + def write_table(self, table): + self.tables[table.name] = table + + writer = TestWriter() + self._setup_emit_test(EmitterEnv(writer)) + assert isinstance( + writer.tables['Foo'].rows, (map, filter, types.GeneratorType) + ) + + def test_emit_env_generator(self): + + class TestEmitterEnv(EmitterEnv): + + def emit_table(self, table_spec): + self.table = table_spec + + env = TestEmitterEnv(JValueTableWriter()) + self._setup_emit_test(env) + assert isinstance(env.table.rows, (map, filter, types.GeneratorType)) + def test_emit_multi_same_query(self): - """Test that we can emit multiple tables from the same set of source data. 
- This is useful if you need to generate multiple tables from the same datasource. + """ + Test that we can emit multiple tables from the same set of + source data. This is useful if you need to generate multiple + tables from the same datasource. """ writer = JValueTableWriter() env = BuiltInEnv() | JsonPathEnv() | EmitterEnv(writer) result = Map( source=Literal([ - {'foo': {'baz': 3, 'bar': True, 'boo': None}}, - {'foo': {'baz': 4, 'bar': False, 'boo': 1}}, + { + 'foo': { + 'baz': 3, + 'bar': True, + 'boo': None + } + }, + { + 'foo': { + 'baz': 4, + 'bar': False, + 'boo': 1 + } + }, ]), body=List([ Emit( table='FooBaz', headings=[Literal('foo')], - source=List([ - List([ Reference('foo.baz')]) - ]), + source=List([List([Reference('foo.baz')])]), ), Emit( table='FooBar', headings=[Literal('foo')], - source=List([ - List([Reference('foo.bar')]) - ]), + source=List([List([Reference('foo.bar')])]), ) - ]), + ]), ).eval(env) # evaluate result @@ -261,10 +705,13 @@ def test_emit_multi_same_query(self): assert writer.tables['FooBar'].rows == [[True], [False]] def test_emit_mutli_different_query(self): - """Test that we can emit multiple tables from the same set of source data even - if the emitted table have different 'root doc' expressions. + """ + Test that we can emit multiple tables from the same set of + source data even if the emitted tables have different 'root doc' + expressions. - Example use case could be emitting cases and case actions, or form data and repeats. + Example use case could be emitting cases and case actions, or + form data and repeats. """ writer = JValueTableWriter() env = BuiltInEnv() | JsonPathEnv() | EmitterEnv(writer) @@ -308,39 +755,77 @@ def test_emit_mutli_different_query(self): # evaluate result list(result) - print(writer.tables) assert writer.tables['t1'].rows == [['1'], ['2']] - assert writer.tables['t2'].rows == [['1', 3], ['1', 4], ['2', 5], ['2', 6]] + assert writer.tables['t2'].rows == [['1', 3], ['1', 4], ['2', 5], + ['2', 6]] def test_from_jvalue(self): - assert MiniLinq.from_jvalue({"Ref": "form.log_subreport"}) == Reference("form.log_subreport") - assert (MiniLinq.from_jvalue({"Apply": {"fn": {"Ref":"len"}, "args": [{"Ref": "form.log_subreport"}]}}) - == Apply(Reference("len"), Reference("form.log_subreport"))) - assert MiniLinq.from_jvalue([{"Ref": "form.log_subreport"}]) == [Reference("form.log_subreport")] + assert MiniLinq.from_jvalue({"Ref": "form.log_subreport"} + ) == Reference("form.log_subreport") + assert ( + MiniLinq.from_jvalue({ + "Apply": { + "fn": { + "Ref": "len" + }, + "args": [{ + "Ref": "form.log_subreport" + }] + } + }) == Apply(Reference("len"), Reference("form.log_subreport")) + ) + assert MiniLinq.from_jvalue([{ + "Ref": "form.log_subreport" + }]) == [Reference("form.log_subreport")] def test_filter(self): env = BuiltInEnv() | DictEnv({}) named = [{'n': n} for n in range(1, 5)] - assert list(Filter(Literal(named), Apply(Reference('>'), Reference('n'), Literal(2))).eval(env)) == [{'n': 3}, {'n': 4}] - assert list(Filter(Literal([1, 2, 3, 4]), Apply(Reference('>'), Reference('n'), Literal(2)), 'n').eval(env)) == [3, 4] + assert list( + Filter( + Literal(named), + Apply(Reference('>'), Reference('n'), Literal(2)) + ).eval(env) + ) == [{ + 'n': 3 + }, { + 'n': 4 + }] + assert list( + Filter( + Literal([1, 2, 3, 4]), + Apply(Reference('>'), Reference('n'), Literal(2)), 'n' + ).eval(env) + ) == [3, 4] def test_emit_table_unwrap_dicts(self): writer = JValueTableWriter() env = EmitterEnv(writer) - env.emit_table(TableSpec(**{ - 'name': 
't1', - 'headings': ['a'], - 'rows':[ - ['hi'], - [{'#text': 'test_text','@case_type': 'person','@relationship': 'child','id': 'nothing'}], - [{'@case_type': '', '@relationship': 'child', 'id': 'some_id'}], - [{'t': 123}], - ] - })) - - writer.tables['t1'].rows = [ - ['hi'], - ['test_text'], - [''], - [{'t': 123}] - ] + env.emit_table( + TableSpec( + **{ + 'name': + 't1', + 'headings': ['a'], + 'rows': [ + ['hi'], + [{ + '#text': 'test_text', + '@case_type': 'person', + '@relationship': 'child', + 'id': 'nothing' + }], + [{ + '@case_type': '', + '@relationship': 'child', + 'id': 'some_id' + }], + [{ + 't': 123 + }], + ] + } + ) + ) + + writer.tables['t1'].rows = [['hi'], ['test_text'], [''], [{'t': 123}]] diff --git a/tests/test_misc.py b/tests/test_misc.py index 136e2476..47ffd73a 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes -import unittest import hashlib -import tempfile import struct +import tempfile +import unittest from commcare_export import misc @@ -11,15 +9,18 @@ class TestDigestFile(unittest.TestCase): def check_digest(self, contents): - with tempfile.NamedTemporaryFile(prefix='commcare-export-test-', mode='wb') as file: - file.write(contents) + with tempfile.NamedTemporaryFile( + prefix='commcare-export-test-', mode='wb' + ) as file: + file.write(contents) file.flush() file_digest = misc.digest_file(file.name) - assert file_digest == hashlib.md5(contents).hexdigest() # Make sure the chunking does not mess with stuff - + # Make sure the chunking does not mess with stuff + assert file_digest == hashlib.md5(contents).hexdigest() + def test_digest_file_ascii(self): - self.check_digest('Hello'.encode('utf-8')) # Even a call to `write` requires encoding (as it should) in Python 3 + self.check_digest('Hello'.encode('utf-8')) def test_digest_file_long(self): self.check_digest(('Hello' * 100000).encode('utf-8')) diff --git a/tests/test_paginator.py b/tests/test_paginator.py new file mode 100644 index 00000000..e3b3e7f9 --- /dev/null +++ b/tests/test_paginator.py @@ -0,0 +1,33 @@ +import unittest + +from commcare_export.checkpoint import CheckpointManagerWithDetails +from commcare_export.commcare_minilinq import ( + DEFAULT_UCR_PAGE_SIZE, + PaginationMode, + get_paginator, +) + + +class PaginatorTest(unittest.TestCase): + def test_ucr_paginator_page_size(self): + checkpoint_manager = CheckpointManagerWithDetails( + None, None, PaginationMode.cursor + ) + paginator = get_paginator( + resource="ucr", + pagination_mode=checkpoint_manager.pagination_mode) + paginator.init() + initial_params = paginator.next_page_params_since( + checkpoint_manager.since_param + ) + self.assertEqual(initial_params["limit"], DEFAULT_UCR_PAGE_SIZE) + + paginator = get_paginator( + resource="ucr", + page_size=1, + pagination_mode=checkpoint_manager.pagination_mode) + paginator.init() + initial_params = paginator.next_page_params_since( + checkpoint_manager.since_param + ) + self.assertEqual(initial_params["limit"], 1) \ No newline at end of file diff --git a/tests/test_repeatable_iterator.py b/tests/test_repeatable_iterator.py index 58098223..89dc601d 100644 --- a/tests/test_repeatable_iterator.py +++ b/tests/test_repeatable_iterator.py @@ -1,9 +1,9 @@ -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes -from itertools import * import unittest +from itertools import * from 
commcare_export.repeatable_iterator import RepeatableIterator + class TestRepeatableIterator(unittest.TestCase): @classmethod @@ -12,10 +12,11 @@ def setup_class(cls): def test_iteration(self): - class LazinessException(Exception): pass + class LazinessException(Exception): + pass - def test1(): - for i in range(1, 100): + def test1(): + for i in range(1, 100): yield i def test2(): @@ -24,12 +25,12 @@ def test2(): raise LazinessException('Not lazy enough') yield i - # First make sure that we've properly set up a situation that fails - # without RepeatableIterator + # First make sure that we've properly set up a situation that + # fails without RepeatableIterator iterator = test1() assert list(iterator) == list(range(1, 100)) assert list(iterator) == [] - + # Now test that the RepeatableIterator restores functionality iterator = RepeatableIterator(test1) assert list(iterator) == list(range(1, 100)) diff --git a/tests/test_version.py b/tests/test_version.py new file mode 100644 index 00000000..c6559fe2 --- /dev/null +++ b/tests/test_version.py @@ -0,0 +1,19 @@ +import pytest + +from commcare_export.version import parse_version + + +@pytest.mark.parametrize( + "input,output", + [ + ("1.2.3", "1.2.3"), + ("1.2", "1.2"), + ("0.1.5-3", "0.1.5.dev3"), + ("0.1.5-3-g1234567", "0.1.5.dev3"), + ("0.1.5-4-g1234567-dirty", "0.1.5.dev4"), + ("0.1.5-15-g1234567-dirty-123", "0.1.5.dev15"), + ("a.b.c", "a.b.c"), + ] +) +def test_parse_version(input, output): + assert parse_version(input) == output diff --git a/tests/test_writers.py b/tests/test_writers.py index b770af3d..102fd00d 100644 --- a/tests/test_writers.py +++ b/tests/test_writers.py @@ -1,18 +1,20 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals, print_function, absolute_import, division, generators, nested_scopes - -import csv342 as csv +import csv import datetime import io import tempfile import zipfile import openpyxl -import pytest import sqlalchemy +import pytest from commcare_export.specs import TableSpec -from commcare_export.writers import SqlTableWriter, JValueTableWriter, Excel2007TableWriter, CsvTableWriter +from commcare_export.writers import ( + CsvTableWriter, + Excel2007TableWriter, + JValueTableWriter, + SqlTableWriter, +) @pytest.fixture() @@ -22,7 +24,11 @@ def writer(db_params): @pytest.fixture() def strict_writer(db_params): - return SqlTableWriter(db_params['url'], poolclass=sqlalchemy.pool.NullPool, strict_types=True) + return SqlTableWriter( + db_params['url'], + poolclass=sqlalchemy.pool.NullPool, + strict_types=True + ) TYPE_MAP = { @@ -33,97 +39,127 @@ def strict_writer(db_params): class TestWriters(object): + def test_JValueTableWriter(self): writer = JValueTableWriter() - writer.write_table(TableSpec(**{ - 'name': 'foo', - 'headings': ['a', 'bjørn', 'c', 'd'], - 'rows': [ - [1, '2', 3, datetime.date(2015, 1, 1)], - [4, '日本', 6, datetime.date(2015, 1, 2)], - ] - })) - - writer.write_table(TableSpec(**{ - 'name': 'foo', - 'headings': ['a', 'bjørn', 'c', 'd'], - 'rows': [ - [5, 'bob', 9, datetime.date(2018, 1, 2)], - ] - })) + writer.write_table( + TableSpec( + **{ + 'name': + 'foo', + 'headings': ['a', 'bjørn', 'c', 'd'], + 'rows': [ + [1, '2', 3, datetime.date(2015, 1, 1)], + [4, '日本', 6, datetime.date(2015, 1, 2)], + ] + } + ) + ) + + writer.write_table( + TableSpec( + **{ + 'name': 'foo', + 'headings': ['a', 'bjørn', 'c', 'd'], + 'rows': [[5, 'bob', 9, + datetime.date(2018, 1, 2)],] + } + ) + ) assert writer.tables == { - 'foo': TableSpec(**{ - 'name': 'foo', - 'headings': ['a', 'bjørn', 'c', 'd'], 
- 'rows': [ - [1, '2', 3, '2015-01-01'], - [4, '日本', 6, '2015-01-02'], - [5, 'bob', 9, '2018-01-02'], - ], - }) + 'foo': + TableSpec( + **{ + 'name': + 'foo', + 'headings': ['a', 'bjørn', 'c', 'd'], + 'rows': [ + [1, '2', 3, '2015-01-01'], + [4, '日本', 6, '2015-01-02'], + [5, 'bob', 9, '2018-01-02'], + ], + } + ) } def test_Excel2007TableWriter(self): with tempfile.NamedTemporaryFile(suffix='.xlsx') as file: with Excel2007TableWriter(file=file) as writer: - writer.write_table(TableSpec(**{ - 'name': 'foo', - 'headings': ['a', 'bjørn', 'c'], - 'rows': [ - [1, '2', 3], - [4, '日本', 6], - ] - })) + writer.write_table( + TableSpec( + **{ + 'name': 'foo', + 'headings': ['a', 'bjørn', 'c'], + 'rows': [ + [1, '2', 3], + [4, '日本', 6], + ] + } + ) + ) self._check_Excel2007TableWriter_output(file.name) def test_Excel2007TableWriter_write_mutli(self): with tempfile.NamedTemporaryFile(suffix='.xlsx') as file: with Excel2007TableWriter(file=file) as writer: - writer.write_table(TableSpec(**{ - 'name': 'foo', - 'headings': ['a', 'bjørn', 'c'], - 'rows': [ - [1, '2', 3], - ] - })) - - writer.write_table(TableSpec(**{ - 'name': 'foo', - 'headings': ['a', 'bjørn', 'c'], - 'rows': [ - [4, '日本', 6], - ] - })) + writer.write_table( + TableSpec( + **{ + 'name': 'foo', + 'headings': ['a', 'bjørn', 'c'], + 'rows': [[1, '2', 3],] + } + ) + ) + + writer.write_table( + TableSpec( + **{ + 'name': 'foo', + 'headings': ['a', 'bjørn', 'c'], + 'rows': [[4, '日本', 6],] + } + ) + ) self._check_Excel2007TableWriter_output(file.name) def _check_Excel2007TableWriter_output(self, filename): - output_wb = openpyxl.load_workbook(filename) - - assert output_wb.sheetnames == ['foo'] - foo_sheet = output_wb['foo'] - assert [ [cell.value for cell in row] for row in foo_sheet['A1:C3']] == [ - ['a', 'bjørn', 'c'], - ['1', '2', '3'], # Note how pyxl does some best-effort parsing to *whatever* type - ['4', '日本', '6'], - ] + output_wb = openpyxl.load_workbook(filename) + + assert output_wb.sheetnames == ['foo'] + foo_sheet = output_wb['foo'] + assert [ + [cell.value for cell in row] for row in foo_sheet['A1:C3'] + ] == [ + ['a', 'bjørn', 'c'], + ['1', '2', '3' + ], # Note how pyxl does some best-effort parsing to *whatever* type + ['4', '日本', '6'], + ] def test_CsvTableWriter(self): with tempfile.NamedTemporaryFile() as file: with CsvTableWriter(file=file) as writer: - writer.write_table(TableSpec(**{ - 'name': 'foo', - 'headings': ['a', 'bjørn', 'c'], - 'rows': [ - [1, '2', 3], - [4, '日本', 6], - ] - })) + writer.write_table( + TableSpec( + **{ + 'name': 'foo', + 'headings': ['a', 'bjørn', 'c'], + 'rows': [ + [1, '2', 3], + [4, '日本', 6], + ] + } + ) + ) with zipfile.ZipFile(file.name, 'r') as output_zip: with output_zip.open('foo.csv') as csv_file: - output = csv.reader(io.TextIOWrapper(csv_file, encoding='utf-8')) + output = csv.reader( + io.TextIOWrapper(csv_file, encoding='utf-8') + ) assert [row for row in output] == [ ['a', 'bjørn', 'c'], @@ -134,11 +170,13 @@ def test_CsvTableWriter(self): @pytest.mark.dbtest class TestSQLWriters(object): + def _type_convert(self, connection, row): """ - Different databases store and return values differently so convert the values - in the expected row to match the DB. + Different databases store and return values differently so + convert the values in the expected row to match the DB. 
""" + def convert(type_map, value): func = type_map.get(value.__class__, None) return func(value) if func else value @@ -151,129 +189,233 @@ def convert(type_map, value): def test_insert(self, writer): with writer: - writer.write_table(TableSpec(**{ - 'name': 'foo_insert', - 'headings': ['id', 'a', 'b', 'c'], - 'rows': [ - ['bizzle', 1, 2, 3], - ['bazzle', 4, 5, 6], - ] - })) - - # We can use raw SQL instead of SqlAlchemy expressions because we built the DB above + writer.write_table( + TableSpec( + **{ + 'name': 'foo_insert', + 'headings': ['id', 'a', 'b', 'c'], + 'rows': [ + ['bizzle', 1, 2, 3], + ['bazzle', 4, 5, 6], + ] + } + ) + ) + + # We can use raw SQL instead of SqlAlchemy expressions because + # we built the DB above with writer: - result = dict([(row['id'], row) for row in writer.connection.execute('SELECT id, a, b, c FROM foo_insert')]) + result = dict([(row['id'], row) for row in writer.connection + .execute('SELECT id, a, b, c FROM foo_insert')]) assert len(result) == 2 - assert dict(result['bizzle']) == {'id': 'bizzle', 'a': 1, 'b': 2, 'c': 3} - assert dict(result['bazzle']) == {'id': 'bazzle', 'a': 4, 'b': 5, 'c': 6} + assert dict(result['bizzle']) == { + 'id': 'bizzle', + 'a': 1, + 'b': 2, + 'c': 3 + } + assert dict(result['bazzle']) == { + 'id': 'bazzle', + 'a': 4, + 'b': 5, + 'c': 6 + } def test_upsert(self, writer): with writer: - writer.write_table(TableSpec(**{ - 'name': 'foo_upsert', - 'headings': ['id', 'a', 'b', 'c'], - 'rows': [ - ['zing', 3, None, 5] - ] - })) + writer.write_table( + TableSpec( + **{ + 'name': 'foo_upsert', + 'headings': ['id', 'a', 'b', 'c'], + 'rows': [['zing', 3, None, 5]] + } + ) + ) # don't select column 'b' since it hasn't been created yet with writer: - result = dict([(row['id'], row) for row in writer.connection.execute('SELECT id, a, c FROM foo_upsert')]) + result = dict([ + (row['id'], row) for row in + writer.connection.execute('SELECT id, a, c FROM foo_upsert') + ]) assert len(result) == 1 assert dict(result['zing']) == {'id': 'zing', 'a': 3, 'c': 5} with writer: - writer.write_table(TableSpec(**{ - 'name': 'foo_upsert', - 'headings': ['id', 'a', 'b', 'c'], - 'rows': [ - ['bizzle', 1, 'yo', 3], - ['bazzle', 4, '日本', 6], - ] - })) - - # We can use raw SQL instead of SqlAlchemy expressions because we built the DB above + writer.write_table( + TableSpec( + **{ + 'name': + 'foo_upsert', + 'headings': ['id', 'a', 'b', 'c'], + 'rows': [ + ['bizzle', 1, 'yo', 3], + ['bazzle', 4, '日本', 6], + ] + } + ) + ) + + # We can use raw SQL instead of SqlAlchemy expressions because + # we built the DB above with writer: - result = dict([(row['id'], row) for row in writer.connection.execute('SELECT id, a, b, c FROM foo_upsert')]) + result = dict([(row['id'], row) for row in writer.connection + .execute('SELECT id, a, b, c FROM foo_upsert')]) assert len(result) == 3 - assert dict(result['bizzle']) == {'id': 'bizzle', 'a': 1, 'b': 'yo', 'c': 3} - assert dict(result['bazzle']) == {'id': 'bazzle', 'a': 4, 'b': '日本', 'c': 6} + assert dict(result['bizzle']) == { + 'id': 'bizzle', + 'a': 1, + 'b': 'yo', + 'c': 3 + } + assert dict(result['bazzle']) == { + 'id': 'bazzle', + 'a': 4, + 'b': '日本', + 'c': 6 + } with writer: - writer.write_table(TableSpec(**{ - 'name': 'foo_upsert', - 'headings': ['id', 'a', 'b', 'c'], - 'rows': [ - ['bizzle', 7, '本', 9], - ] - })) - - # We can use raw SQL instead of SqlAlchemy expressions because we built the DB above + writer.write_table( + TableSpec( + **{ + 'name': 'foo_upsert', + 'headings': ['id', 'a', 'b', 'c'], + 'rows': 
[['bizzle', 7, '本', 9],] + } + ) + ) + + # We can use raw SQL instead of SqlAlchemy expressions because + # we built the DB above with writer: - result = dict([(row['id'], row) for row in writer.connection.execute('SELECT id, a, b, c FROM foo_upsert')]) + result = dict([(row['id'], row) for row in writer.connection + .execute('SELECT id, a, b, c FROM foo_upsert')]) assert len(result) == 3 - assert dict(result['bizzle']) == {'id': 'bizzle', 'a': 7, 'b': '本', 'c': 9} - assert dict(result['bazzle']) == {'id': 'bazzle', 'a': 4, 'b': '日本', 'c': 6} + assert dict(result['bizzle']) == { + 'id': 'bizzle', + 'a': 7, + 'b': '本', + 'c': 9 + } + assert dict(result['bazzle']) == { + 'id': 'bazzle', + 'a': 4, + 'b': '日本', + 'c': 6 + } def test_types(self, writer): self._test_types(writer, 'foo_fancy_types') def _test_types(self, writer, table_name): with writer: - writer.write_table(TableSpec(**{ - 'name': table_name, - 'headings': ['id', 'a', 'b', 'c', 'd', 'e'], - 'rows': [ - ['bizzle', 1, 'yo', True, datetime.date(2015, 1, 1), datetime.datetime(2014, 4, 2, 18, 56, 12)], - ['bazzle', 4, '日本', False, datetime.date(2015, 1, 2), datetime.datetime(2014, 5, 1, 11, 16, 45)], - ] - })) - - # We can use raw SQL instead of SqlAlchemy expressions because we built the DB above + writer.write_table( + TableSpec( + **{ + 'name': + table_name, + 'headings': ['id', 'a', 'b', 'c', 'd', 'e'], + 'rows': [ + [ + 'bizzle', 1, 'yo', True, + datetime.date(2015, 1, 1), + datetime.datetime(2014, 4, 2, 18, 56, 12) + ], + [ + 'bazzle', 4, '日本', False, + datetime.date(2015, 1, 2), + datetime.datetime(2014, 5, 1, 11, 16, 45) + ], + ] + } + ) + ) + + # We can use raw SQL instead of SqlAlchemy expressions because + # we built the DB above with writer: connection = writer.connection - result = dict( - [(row['id'], row) for row in connection.execute('SELECT id, a, b, c, d, e FROM %s' % table_name)]) + result = dict([ + (row['id'], row) for row in connection + .execute('SELECT id, a, b, c, d, e FROM %s' % table_name) + ]) assert len(result) == 2 expected = { - 'bizzle': {'id': 'bizzle', 'a': 1, 'b': 'yo', 'c': True, - 'd': datetime.date(2015, 1, 1), 'e': datetime.datetime(2014, 4, 2, 18, 56, 12)}, - 'bazzle': {'id': 'bazzle', 'a': 4, 'b': '日本', 'c': False, - 'd': datetime.date(2015, 1, 2), 'e': datetime.datetime(2014, 5, 1, 11, 16, 45)} + 'bizzle': { + 'id': 'bizzle', + 'a': 1, + 'b': 'yo', + 'c': True, + 'd': datetime.date(2015, 1, 1), + 'e': datetime.datetime(2014, 4, 2, 18, 56, 12) + }, + 'bazzle': { + 'id': 'bazzle', + 'a': 4, + 'b': '日本', + 'c': False, + 'd': datetime.date(2015, 1, 2), + 'e': datetime.datetime(2014, 5, 1, 11, 16, 45) + } } for id, row in result.items(): assert id in expected - assert dict(row) == self._type_convert(connection, expected[id]) + assert dict(row + ) == self._type_convert(connection, expected[id]) def test_change_type(self, writer): self._test_types(writer, 'foo_fancy_type_changes') with writer: - writer.write_table(TableSpec(**{ - 'name': 'foo_fancy_type_changes', - 'headings': ['id', 'a', 'b', 'c', 'd', 'e'], - 'rows': [ - ['bizzle', 'yo dude', '本', 'true', datetime.datetime(2015, 2, 13), '2014-08-01T11:23:45:00.0000Z'], - ] - })) - - # We can use raw SQL instead of SqlAlchemy expressions because we built the DB above + writer.write_table( + TableSpec( + **{ + 'name': + 'foo_fancy_type_changes', + 'headings': ['id', 'a', 'b', 'c', 'd', 'e'], + 'rows': [[ + 'bizzle', 'yo dude', '本', 'true', + datetime.datetime(2015, 2, 13), + '2014-08-01T11:23:45:00.0000Z' + ],] + } + ) + ) + + # We can use raw SQL 
instead of SqlAlchemy expressions because + # we built the DB above with writer: - result = dict([(row['id'], row) for row in - writer.connection.execute('SELECT id, a, b, c, d, e FROM foo_fancy_type_changes')]) + result = dict([ + (row['id'], row) for row in writer.connection.execute( + 'SELECT id, a, b, c, d, e FROM foo_fancy_type_changes' + ) + ]) assert len(result) == 2 expected = { - 'bizzle': {'id': 'bizzle', 'a': 'yo dude', 'b': '本', 'c': 'true', - 'd': datetime.date(2015, 2, 13), 'e': '2014-08-01T11:23:45:00.0000Z'}, - 'bazzle': {'id': 'bazzle', 'a': '4', 'b': '日本', 'c': 'false', - 'd': datetime.date(2015, 1, 2), 'e': '2014-05-01 11:16:45'} + 'bizzle': { + 'id': 'bizzle', + 'a': 'yo dude', + 'b': '本', + 'c': 'true', + 'd': datetime.date(2015, 2, 13), + 'e': '2014-08-01T11:23:45:00.0000Z' + }, + 'bazzle': { + 'id': 'bazzle', + 'a': '4', + 'b': '日本', + 'c': 'false', + 'd': datetime.date(2015, 1, 2), + 'e': '2014-05-01 11:16:45' + } } if 'mysql' in writer.connection.engine.driver: @@ -281,39 +423,267 @@ def test_change_type(self, writer): expected['bazzle']['c'] = '0' if 'pyodbc' in writer.connection.engine.driver: expected['bazzle']['c'] = '0' - # couldn't figure out how to make SQL Server convert date to ISO8601 - # see https://docs.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-2017#date-and-time-styles - expected['bazzle']['e'] = 'May 1 2014 11:16AM' + # MSSQL includes fractional seconds in returned value. + expected['bazzle']['e'] = '2014-05-01 11:16:45.0000000' for id, row in result.items(): assert id in expected assert dict(row) == expected[id] + def test_json_type(self, writer): + complex_object = { + 'poke1': { + 'name': 'snorlax', + 'color': 'blue', + 'attributes': { + 'strength': 10, + 'endurance': 10, + 'speed': 4, + }, + 'friends': [ + 'pikachu', + 'charmander', + ], + }, + 'poke2': { + 'name': 'pikachu', + 'color': 'yellow', + 'attributes': { + 'strength': 2, + 'endurance': 2, + 'speed': 8, + 'cuteness': 10, + }, + 'friends': [ + 'snorlax', + 'charmander', + ], + }, + } + with writer: + if not writer.is_postgres: + return + writer.write_table( + TableSpec( + **{ + 'name': 'foo_with_json', + 'headings': ['id', 'json_col'], + 'rows': [ + ['simple', { + 'k1': 'v1', + 'k2': 'v2' + }], + ['with_lists', { + 'l1': ['i1', 'i2'] + }], + ['complex', complex_object], + ], + 'data_types': [ + 'text', + 'json', + ] + } + ) + ) + + # We can use raw SQL instead of SqlAlchemy expressions because + # we built the DB above + with writer: + result = dict([(row['id'], row) for row in writer.connection + .execute('SELECT id, json_col FROM foo_with_json')]) + + assert len(result) == 3 + assert dict(result['simple']) == { + 'id': 'simple', + 'json_col': { + 'k1': 'v1', + 'k2': 'v2' + } + } + assert dict(result['with_lists']) == { + 'id': 'with_lists', + 'json_col': { + 'l1': ['i1', 'i2'] + } + } + assert dict(result['complex']) == { + 'id': 'complex', + 'json_col': complex_object + } def test_explicit_types(self, strict_writer): with strict_writer: - strict_writer.write_table(TableSpec(**{ - 'name': 'foo_explicit_types', - 'headings': ['id', 'a', 'b', 'c', 'd'], - 'rows': [ - ['bizzle', '1', 2, 3, '7'], - ['bazzle', '4', 5, 6, '8'], - ], - 'data_types': [ - 'text', - 'integer', - 'text', - None, - ] - })) - - # We can use raw SQL instead of SqlAlchemy expressions because we built the DB above + strict_writer.write_table( + TableSpec( + **{ + 'name': 'foo_explicit_types', + 'headings': ['id', 'a', 'b', 'c', 'd'], + 'rows': [ + ['bizzle', '1', 2, 3, '7'], + 
['bazzle', '4', 5, 6, '8'], + ], + 'data_types': [ + 'text', + 'integer', + 'text', + None, + ] + } + ) + ) + + # We can use raw SQL instead of SqlAlchemy expressions because + # we built the DB above with strict_writer: - result = dict([(row['id'], row) for row in strict_writer.connection.execute( - 'SELECT id, a, b, c, d FROM foo_explicit_types' - )]) + result = dict([ + (row['id'], row) for row in strict_writer.connection + .execute('SELECT id, a, b, c, d FROM foo_explicit_types') + ]) assert len(result) == 2 # a casts strings to ints, b casts ints to text, c default falls back to ints, d default falls back to text - assert dict(result['bizzle']) == {'id': 'bizzle', 'a': 1, 'b': '2', 'c': 3, 'd': '7'} - assert dict(result['bazzle']) == {'id': 'bazzle', 'a': 4, 'b': '5', 'c': 6, 'd': '8'} + assert dict(result['bizzle']) == { + 'id': 'bizzle', + 'a': 1, + 'b': '2', + 'c': 3, + 'd': '7' + } + assert dict(result['bazzle']) == { + 'id': 'bazzle', + 'a': 4, + 'b': '5', + 'c': 6, + 'd': '8' + } + + def test_mssql_nvarchar_length_upsize(self, writer): + with writer: + if 'odbc' not in writer.connection.engine.driver: + return + + # Initialize a table with columns where we expect the + # "some_data" column to be of length 900 bytes, and the + # "big_data" column to be of nvarchar(max) + writer.write_table( + TableSpec( + **{ + 'name': + 'mssql_nvarchar_length', + 'headings': ['id', 'some_data', 'big_data'], + 'rows': [ + [ + 'bizzle', (b'\0' * 800).decode('utf-8'), + (b'\0' * 901).decode('utf-8') + ], + [ + 'bazzle', (b'\0' * 500).decode('utf-8'), + (b'\0' * 800).decode('utf-8') + ], + ] + } + ) + ) + + connection = writer.connection + + result = self._get_column_lengths( + connection, 'mssql_nvarchar_length' + ) + assert result['some_data'] == ('some_data', 'nvarchar', 900) + # nvarchar(max) is listed as -1 + assert result['big_data'] == ('big_data', 'nvarchar', -1) + + # put bigger data into "some_column" to ensure it is resized + # properly + writer.write_table( + TableSpec( + **{ + 'name': + 'mssql_nvarchar_length', + 'headings': ['id', 'some_data', 'big_data'], + 'rows': [[ + 'sizzle', (b'\0' * 901).decode('utf-8'), + (b'\0' * 901).decode('utf-8') + ],] + } + ) + ) + + result = self._get_column_lengths( + connection, 'mssql_nvarchar_length' + ) + assert result['some_data'] == ('some_data', 'nvarchar', -1) + assert result['big_data'] == ('big_data', 'nvarchar', -1) + + def test_mssql_nvarchar_length_downsize(self, writer): + with writer: + if 'odbc' not in writer.connection.engine.driver: + return + + # Initialize a table with NVARCHAR(max), and make sure + # smaller data doesn't reduce the size of the column + metadata = sqlalchemy.MetaData() + create_sql = sqlalchemy.schema.CreateTable( + sqlalchemy.Table( + 'mssql_nvarchar_length_downsize', + metadata, + sqlalchemy.Column( + 'id', + sqlalchemy.NVARCHAR(length=100), + primary_key=True + ), + sqlalchemy.Column( + 'some_data', sqlalchemy.NVARCHAR(length=None) + ), + ) + ).compile(writer.connection.engine) + metadata.create_all(writer.connection.engine) + + writer.write_table( + TableSpec( + **{ + 'name': + 'mssql_nvarchar_length', + 'headings': ['id', 'some_data'], + 'rows': [ + [ + 'bizzle', (b'\0' * 800).decode('utf-8'), + (b'\0' * 800).decode('utf-8') + ], + [ + 'bazzle', (b'\0' * 500).decode('utf-8'), + (b'\0' * 800).decode('utf-8') + ], + ] + } + ) + ) + result = self._get_column_lengths( + writer.connection, 'mssql_nvarchar_length_downsize' + ) + assert result['some_data'] == ('some_data', 'nvarchar', -1) + + def 
test_big_lump_of_poo(self, writer): + with writer: + writer.write_table( + TableSpec( + **{ + 'name': 'foo_with_emoji', + 'headings': ['id', 'fun_to_be_had'], + 'rows': [ + ['A steaming poo', '💩'], + ['2020', '😷'], + ], + } + ) + ) + + def _get_column_lengths(self, connection, table_name): + return { + row['COLUMN_NAME']: row for row in connection.execute( + "SELECT COLUMN_NAME, DATA_TYPE, CHARACTER_MAXIMUM_LENGTH " + "FROM INFORMATION_SCHEMA.COLUMNS " + "WHERE TABLE_NAME = '{}';".format(table_name) + ) + } diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..2397dc46 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,18 @@ +from commcare_export.writers import SqlTableWriter + + +class SqlWriterWithTearDown(SqlTableWriter): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.tables = set() + + def write_table(self, table_spec): + super().write_table(table_spec) + if table_spec.rows: + self.tables.add(table_spec.name) + + def tear_down(self): + for table in self.tables: + self.engine.execute(f'DROP TABLE "{table}"') + self.tables = set()
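# ---------------------------------------------------------------------------
# Illustrative sketch only; not part of the changeset above. It shows how the
# new tests/utils.py helper, SqlWriterWithTearDown, might be wired into a
# pytest fixture so that any table a test writes is dropped afterwards. The
# `db_params` fixture name is borrowed from tests/test_writers.py, and the
# import path is an assumption about how the tests package is laid out.
import pytest
import sqlalchemy

from tests.utils import SqlWriterWithTearDown


@pytest.fixture()
def sql_writer_with_teardown(db_params):
    writer = SqlWriterWithTearDown(
        db_params['url'], poolclass=sqlalchemy.pool.NullPool
    )
    yield writer
    # Drops every table recorded by write_table() so later tests start clean.
    writer.tear_down()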