-
Notifications
You must be signed in to change notification settings - Fork 5
/
dataset.py
12 lines (11 loc) · 1.37 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
from absl import flags
# Command-line flag choosing which dataset implementation load() binds;
# absl rejects any value outside the listed choices at parse time.
flags.DEFINE_enum(
    name='dataset',
    default='stacks',
    enum_values=['stacks', 'navigation', 'regex'],
    help='dataset to use')
# Module-level handle on absl's global flag registry (read by load()).
FLAGS = flags.FLAGS
# Maps each supported dataset name to the module that implements it.
_DATASET_MODULES = {
    'stacks': 'shrdlurn.stack_dataset',
    'navigation': 'navigation.nav_dataset',
    'regex': 'regexp.regex_dataset',
}

# Names every dataset module must provide; load() copies them into this
# module's global namespace. Declared once so the dataset modules cannot
# drift out of sync with each other.
_EXPORTED_NAMES = (
    'get_session_ids',
    'get_session_data',
    'state_to_variable',
    'output_to_variable',
    'output_from_variable',
    'get_states_and_actions',
    'LanguageModule',
    'Encoder',
    'Decoder',
    'loss',
    'state_to_variable_batch',
    'output_to_variable_batch',
    'output_from_variable_batch',
    'LSTMLanguageModule',
)


def load(dataset=None):
    """Bind the selected dataset's implementation into this module's globals.

    Imports the module registered for ``dataset`` in ``_DATASET_MODULES`` and
    copies every name listed in ``_EXPORTED_NAMES`` from it into this module's
    global namespace. Other code can then access those names as attributes of
    this module.

    Args:
        dataset: Key of ``_DATASET_MODULES`` to load. When omitted, the value
            of the ``--dataset`` flag (``FLAGS.dataset``) is used, which
            requires absl flags to have been parsed already.

    Raises:
        ValueError: If ``dataset`` is not a known dataset name.
        ImportError: If the dataset module cannot be imported or does not
            define one of the required names.
    """
    import importlib

    if dataset is None:
        dataset = FLAGS.dataset
    if dataset not in _DATASET_MODULES:
        # Fail loudly instead of silently falling back to one of the datasets.
        raise ValueError('Unknown dataset %r; expected one of %s'
                         % (dataset, sorted(_DATASET_MODULES)))
    module_name = _DATASET_MODULES[dataset]

    module = importlib.import_module(module_name)
    try:
        exports = dict((name, getattr(module, name)) for name in _EXPORTED_NAMES)
    except AttributeError as err:
        # Keep the exception type a `from module import name` statement raises.
        raise ImportError('Dataset module %r is missing a required name: %s'
                          % (module_name, err))
    # Bind everything in one step: a missing name leaves existing globals untouched.
    globals().update(exports)