Skip to content

Commit

Permalink
Merge pull request #781 from ufal/tf-data-2
Browse files Browse the repository at this point in the history
Towards TF dataset, part III
  • Loading branch information
jindrahelcl authored Jan 8, 2019
2 parents 2c71059 + e3f0f68 commit b384686
Show file tree
Hide file tree
Showing 94 changed files with 828 additions and 1,223 deletions.
6 changes: 1 addition & 5 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@ env:
- TEST_SUITE=mypy

python:
#- "2.7"
#- "3.4"
- "3.5"
#- "3.5-dev" # 3.5 development branch
#- "nightly" # currently points to 3.6-dev
- "3.6"

# commands to install dependencies
before_install:
Expand Down
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ python_speech_features
pygments
typeguard
sacrebleu
tensorflow>=1.10.0,<1.11
tensorflow>=1.12.0,<1.13
52 changes: 1 addition & 51 deletions neuralmonkey/checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,64 +4,14 @@
constructing the computational graph.
"""


from typing import List, Optional, Iterable

from typing import List, Optional
import tensorflow as tf

from neuralmonkey.logging import log, debug
from neuralmonkey.dataset import Dataset
from neuralmonkey.runners.base_runner import BaseRunner


class CheckingException(Exception):
    """Raised when a dataset is missing data series required by the model.

    See ``check_dataset_and_coders``, which raises this exception after
    collecting all series the runners' coders expect.
    """


def check_dataset_and_coders(dataset: Dataset,
                             runners: Iterable[BaseRunner]) -> None:
    """Check that the dataset contains all series the runners need.

    Walks the feedable coders of each runner, collecting the data series
    they declare via a ``data_id`` / ``data_ids`` attribute (either on the
    coder itself or on its ``input_sequence``), and verifies that every
    collected series is present in the dataset.

    Arguments:
        dataset: The dataset to check.
        runners: Runners whose feedable coders declare required series.

    Raises:
        CheckingException: If the dataset lacks any required series.
    """
    # pylint: disable=protected-access

    data_list = []
    for runner in runners:
        for coder in runner.feedables:
            if hasattr(coder, "data_id"):
                data_list.append((getattr(coder, "data_id"), coder))
            elif hasattr(coder, "data_ids"):
                data_list.extend(
                    (d, coder) for d in getattr(coder, "data_ids"))
            elif hasattr(coder, "input_sequence"):
                inpseq = getattr(coder, "input_sequence")
                if hasattr(inpseq, "data_id"):
                    data_list.append((getattr(inpseq, "data_id"), coder))
                elif hasattr(inpseq, "data_ids"):
                    data_list.extend(
                        (d, coder) for d in getattr(inpseq, "data_ids"))
                else:
                    log("Input sequence: {} does not have a data attribute"
                        .format(str(inpseq)))
            else:
                # Fix: original message contained a duplicated article
                # ("nor a a data attribute").
                log(("Coder: {} has neither an input sequence attribute nor "
                     "a data attribute.").format(coder))

    debug("Found series: {}".format(str(data_list)), "checking")
    missing = []

    for (serie, coder) in data_list:
        if serie not in dataset:
            log("dataset {} does not have serie {}".format(
                dataset.name, serie))
            missing.append((coder, serie))

    if missing:
        formatted = ["{} ({}, {}.{})".format(serie, str(cod),
                                             cod.__class__.__module__,
                                             cod.__class__.__name__)
                     for cod, serie in missing]

        # Fix: original exception message read "is mising series".
        raise CheckingException("Dataset '{}' is missing series {}:"
                                .format(dataset.name, ", ".join(formatted)))


def assert_shape(tensor: tf.Tensor,
expected_shape: List[Optional[int]]) -> None:
"""Check shape of a tensor.
Expand Down
4 changes: 2 additions & 2 deletions neuralmonkey/checkpython.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys

# Reconstructed from the diff fragment, which interleaved old and new
# lines and was not valid Python as shown. Tuple comparison checks both
# version components at once; the element-wise form
# ``version_info[0] < 3 or version_info[1] < 6`` would wrongly reject
# any future major version whose minor is below 6 (e.g. 4.0).
if sys.version_info < (3, 6):
    print("Error:", file=sys.stderr)
    print("Neural Monkey must use Python >= 3.6", file=sys.stderr)
    print("Your Python is", sys.version, sys.executable, file=sys.stderr)
    sys.exit(1)
20 changes: 0 additions & 20 deletions neuralmonkey/config/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import numpy as np

from neuralmonkey.dataset import BatchingScheme
from neuralmonkey.logging import warn
from neuralmonkey.tf_manager import get_default_tf_manager
from neuralmonkey.trainers.delayed_update_trainer import DelayedUpdateTrainer
Expand All @@ -34,25 +33,6 @@ def normalize_configuration(cfg: Namespace, train_mode: bool) -> None:
if cfg.tf_manager is None:
cfg.tf_manager = get_default_tf_manager()

if (cfg.batch_size is None) == (cfg.batching_scheme is None):
raise ValueError("You must specify either batch_size or "
"batching_scheme (not both).")

if cfg.batch_size is not None:
assert cfg.batching_scheme is None
cfg.batching_scheme = BatchingScheme(batch_size=cfg.batch_size)
else:
assert cfg.batching_scheme is not None
cfg.batch_size = cfg.batching_scheme.batch_size

if cfg.runners_batch_size is None:
cfg.runners_batch_size = cfg.batching_scheme.batch_size

cfg.runners_batching_scheme = BatchingScheme(
batch_size=cfg.runners_batch_size,
token_level_batching=cfg.batching_scheme.token_level_batching,
use_leftover_buckets=True)

cfg.evaluation = [(e[0], e[0], e[1]) if len(e) == 2 else e
for e in cfg.evaluation]

Expand Down
3 changes: 2 additions & 1 deletion neuralmonkey/config/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,13 @@ def _parse_class_name(string: str, vars_dict: VarsDict) -> ClassSymbol:


def _parse_value(string: str, vars_dict: VarsDict) -> Any:
"""Parse the value recursively according to the Nerualmonkey grammar.
"""Parse the value recursively according to the Nerual Monkey grammar.
Arguments:
string: the string to be parsed
vars_dict: a dictionary of variables for substitution
"""
string = string.strip()

if string in CONSTANTS:
return CONSTANTS[string]
Expand Down
Loading

0 comments on commit b384686

Please sign in to comment.