Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

experimenting with pydantic #119

Draft
wants to merge 3 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion argschema/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
'''argschema: flexible definition, validation and setting of parameters'''
from .fields import InputFile, InputDir, OutputFile, OptionList # noQA:F401
from .fields import OutputFile, OptionList # noQA:F401
from .schemas import ArgSchema # noQA:F401
from .argschema_parser import ArgSchemaParser # noQA:F401
from .deprecated import JsonModule, ModuleParameters # noQA:F401
Expand Down
4 changes: 2 additions & 2 deletions argschema/fields/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
'''sub-module for custom marshmallow fields of general utility'''
from marshmallow.fields import * # noQA:F401
from marshmallow.fields import __all__ as __mmall__ # noQA:F401
from .files import OutputFile, InputDir, InputFile, OutputDir # noQA:F401
from .files import OutputFile, OutputDir # noQA:F401
from .numpyarrays import NumpyArray # noQA:F401
from .deprecated import OptionList # noQA:F401
from .loglevel import LogLevel # noQA:F401
from .slice import Slice # noQA:F401

__all__ = __mmall__ + ['OutputFile', 'InputDir', 'InputFile', 'OutputDir',
__all__ = __mmall__ + ['OutputFile','OutputDir',
'NumpyArray', 'OptionList', 'LogLevel', 'Slice']

# Python 2 subpackage (not module) * imports break if items in __all__
Expand Down
105 changes: 31 additions & 74 deletions argschema/fields/files.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
'''marshmallow fields related to validating input and output file paths'''
import os
import marshmallow as mm
import tempfile
import errno
import sys
Expand Down Expand Up @@ -32,27 +31,28 @@ def __exit__(self, *args):
def validate_outpath(path):
try:
with NamedTemporaryFile(mode='w', dir=path) as tfile:
print(tfile)
tfile.write('0')
tfile.close()

except Exception as e:
if isinstance(e, OSError):
if e.errno == errno.ENOENT:
raise mm.ValidationError(
raise ValueError(
"%s is not in a directory that exists" % path)
elif e.errno == errno.EACCES:
raise mm.ValidationError(
raise ValueError(
"%s does not appear you can write to path" % path)
else:
raise mm.ValidationError(
raise ValueError(
"Unknown OSError: {}".format(e.message))
else:
raise mm.ValidationError(
raise ValueError(
"Unknown Exception: {}".format(e.message))


class OutputFile(mm.fields.Str):
"""OutputFile :class:`marshmallow.fields.Str` subclass which is a path to a
class OutputFile(str):
"""OutputFile :class:`str` subclass which is a path to a
file location that can be written to by the current user
(presently tested by opening a temporary file to that
location)
Expand All @@ -65,7 +65,12 @@ class OutputFile(mm.fields.Str):

"""

def _validate(self, value):
@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
def validate(cls, value):
"""

Parameters
Expand All @@ -86,104 +91,56 @@ def _validate(self, value):
try:
path = os.path.dirname(value)
except Exception as e: # pragma: no cover
raise mm.ValidationError(
raise ValueError(
"%s cannot be os.path.dirname-ed" % value) # pragma: no cover

validate_outpath(path)

return cls(value)

class OutputDirModeException(Exception):
pass

class OutputDir(mm.fields.Str):
"""OutputDir is a :class:`marshmallow.fields.Str` subclass which is a path to
class OutputDir(str):
"""OutputDir is a :class:`str` subclass which is a path to
a location where this module will write files. Validation will check that
the directory exists and create the directory if it is not present,
and will fail validation if the directory cannot be created or cannot be
written to.

Parameters
==========
mode: str
mode to create directory
*args:
smae as passed to marshmallow.fields.Str
**kwargs:
same as passed to marshmallow.fields.Str
"""

def __init__(self, mode=None, *args, **kwargs):
self.mode = mode
if (self.mode is not None) & (sys.platform == "win32"):
raise OutputDirModeException(
"Setting mode of OutputDir supported only on posix systems")
super(OutputDir, self).__init__(*args, **kwargs)

def _validate(self, value):
@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
def validate(cls, value):
if not os.path.isdir(value):
try:
os.makedirs(value)
if self.mode is not None:
os.chmod(value, self.mode)
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise mm.ValidationError(
raise ValueError(
"{} is not a directory and you cannot create it".format(
value)
)
if self.mode is not None:
try:
assert((os.stat(value).st_mode & 0o777) == self.mode)
except AssertionError:
raise mm.ValidationError(
"{} does not have the mode ({}) that was specified ".format(
value, self.mode)
)
except os.error:
raise mm.ValidationError(
"cannot get os.stat of {}".format(value)
)

# use outputfile to test that a file in this location is a valid path
validate_outpath(value)

return value


def validate_input_path(value):
if not os.path.isfile(value):
raise mm.ValidationError("%s is not a file" % value)
raise ValueError("%s is not a file" % value)
else:
try:
with open(value) as f:
pass
except Exception as value:
raise mm.ValidationError("%s is not readable" % value)

class InputDir(mm.fields.Str):
"""InputDir is :class:`marshmallow.fields.Str` subclass which is a path to a
a directory that exists and that the user can access
(presently checked with os.access)
"""

def _validate(self, value):
if not os.path.isdir(value):
raise mm.ValidationError("%s is not a directory")

if sys.platform == "win32":
try:
x = list(os.scandir(value))
except PermissionError:
raise mm.ValidationError(
"%s is not a readable directory" % value)
else:
if not os.access(value, os.R_OK):
raise mm.ValidationError(
"%s is not a readable directory" % value)


class InputFile(mm.fields.Str):
"""InputDile is a :class:`marshmallow.fields.Str` subclass which is a path to a
file location which can be read by the user
(presently passes os.path.isfile and os.access = R_OK)
"""
raise ValueError("%s is not readable" % value)

def _validate(self, value):
validate_input_path(value)
88 changes: 54 additions & 34 deletions argschema/schemas.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,64 @@
import marshmallow as mm
from .fields import LogLevel, InputFile, OutputFile
from pydantic import BaseModel, Field, FilePath
from pydantic.main import ModelMetaclass
from typing import get_origin
from enum import Enum
import logging
import argparse
from .fields import OutputFile

class LogLevel(Enum):
DEBUG = logging.DEBUG
INFO = logging.INFO
WARNING = logging.WARNING
ERROR = logging.ERROR
CRITICAL = logging.CRITICAL

class DefaultSchema(mm.Schema):
"""mm.Schema class with support for making fields default to
values defined by that field's arguments.
"""
class ArgSchema(BaseModel):
input_json: FilePath = Field('input.json', description='zee inputs')
output_json: OutputFile = Field('output.json', description='zee outputs')
log_level: LogLevel = Field(logging.ERROR, description='zee log level')

@mm.pre_load
def make_object(self, in_data, **kwargs):
"""marshmallow.pre_load decorated function for applying defaults on deserialation
@classmethod
def from_args(cls, a):
arg_data = vars(a)
with open(arg_data['input_json'],'r') as f:
input_data = json.load(f)

Parameters
----------
in_data :
input_data['input_json'] = arg_data['input_json']
input_data['output_json'] = arg_data['output_json']
input_data['log_level'] = arg_data['log_level']

return populate_schema_from_data(cls, input_data, arg_data)

@classmethod
def argument_parser(cls, *args, **kwargs):
parser = argparse.ArgumentParser(*args, **kwargs)
add_arguments_from_schema(parser, cls)
return parser

Returns
-------
dict
a dictionary with default values applied
def populate_schema_from_data(schema, input_data, arg_data):
xdata = {}

"""
for name, field in self.fields.items():
if name not in in_data:
if field.default is not mm.missing:
in_data[name] = field.default
return in_data
for field_name, field in schema.__fields__.items():
if isinstance(field.outer_type_, ModelMetaclass):
sub_input_data = input_data[field_name]
sub_arg_data = { k.replace(f'{field_name}.',''):v for k,v in arg_data.items() if k.startswith(field_name)}
xdata[field_name] = populate_schema_from_data(field.type_, sub_input_data, sub_arg_data)
else:
arg_value = arg_data.get(field_name, None)
xdata[field_name] = input_data[field_name] if arg_value is None else arg_value

return schema(**xdata)

def add_arguments_from_schema(parser, schema, parent_prefix=''):
for field_name, field in schema.__fields__.items():
fn = field_name.replace('_','-')

if isinstance(field.outer_type_, ModelMetaclass):
add_arguments_from_schema(parser, field.outer_type_, parent_prefix=f'{parent_prefix}{fn}.')
elif get_origin(field.outer_type_) is list:
parser.add_argument(f'--{parent_prefix}{fn}', nargs='+', type=field.type_, default=field.default, help=field.field_info.description)
else:
parser.add_argument(f'--{parent_prefix}{fn}', type=field.type_, default=field.default, help=field.field_info.description)

class ArgSchema(DefaultSchema):
"""The base marshmallow schema used by ArgSchemaParser to identify
input_json and output_json files and the log_level
"""

input_json = InputFile(
description="file path of input json file")

output_json = OutputFile(
description="file path to output json file")
log_level = LogLevel(
default='ERROR',
description="set the logging level of the module")
Loading