Skip to content

Commit

Permalink
Merge pull request #606 from sloweater/local_files
Browse files Browse the repository at this point in the history
Make task try to open input locally before going over the network.
  • Loading branch information
pooya committed Feb 8, 2015
2 parents 309151e + bdcddc7 commit c8fd71c
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 7 deletions.
3 changes: 2 additions & 1 deletion lib/disco/schemes/scheme_disco.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ def input_stream(fd, size, url, params):
"""
Opens the path on host locally if it is local, otherwise over http.
"""
file = open(url, task=globals().get('Task'))
import disco.worker
file = open(url, task=disco.worker.active_task)
return file, len(file), file.url
3 changes: 2 additions & 1 deletion lib/disco/schemes/scheme_discodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def Open(url, task=None):
return discodb

def input_stream(fd, size, url, params):
return Open(url, task=globals().get('Task')), size, url
import disco.worker
return Open(url, task=disco.worker.active_task), size, url

class DiscoDBOutput(object):
def __init__(self, stream, params):
Expand Down
3 changes: 2 additions & 1 deletion lib/disco/schemes/scheme_hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ def open(url, task=None):
return comm.open_url(http_url)

def input_stream(fd, size, url, params):
file = open(url, task=globals().get('Task'))
import disco.worker
file = open(url, task=disco.worker.active_task)
return file, len(file), file.url
3 changes: 2 additions & 1 deletion lib/disco/schemes/scheme_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ def open(url, task=None):

def input_stream(fd, sze, url, params):
"""Opens the specified url using an http client."""
file = open(url, task=globals().get('Task'))
import disco.worker
file = open(url, task=disco.worker.active_task)
return file, len(file), file.url
27 changes: 24 additions & 3 deletions lib/disco/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
#. (node) worker requests the :class:`disco.task.Task` from the master
#. (node) worker runs the :term:`task` and reports the output to the master
"""
import os, sys, time, traceback
import os, sys, time, traceback, random

from disco.compat import basestring, force_utf8
from disco.error import DataError
Expand All @@ -63,6 +63,11 @@
# Maximum amount of time a task might take to finish.
DISCO_WORKER_MAX_TIME = 24 * 60 * 60

# Use active_task as a global variable.
# I will set this when a task is running, and then access it from utilities that
# need task data like ddfs directory, etc.
active_task = None

class MessageWriter(object):
def __init__(self, worker):
self.worker = worker
Expand Down Expand Up @@ -480,14 +485,30 @@ def __str__(self):
class ReplicaIter(object):
def __init__(self, input):
self.input, self.used = input, set()
self.checked_local = False

def __iter__(self):
return self

def next(self):
from disco.util import urlsplit
replicas = dict(self.input.replicas)
repl_ids = set(replicas) - self.used
for repl_id in repl_ids:
repl_ids = list(set(replicas) - self.used)
if not self.checked_local and active_task: # Try to favor opening a local file
self.checked_local = True
for repl_id in repl_ids:
replica = replicas[repl_id]
scheme, netloc, rest = urlsplit(replica,
localhost=active_task.host,
ddfs_data=active_task.ddfs_data,
disco_data=active_task.disco_data)
if scheme == 'file':
self.used.add(repl_id)
if os.path.exists(rest): # file might not exist due to replica rebalancing
return replica
repl_ids.remove(repl_id)
if repl_ids:
repl_id = random.choice(repl_ids)
self.used.add(repl_id)
return replicas[repl_id]
self.input.unavailable(self.used)
Expand Down
1 change: 1 addition & 0 deletions lib/disco/worker/classic/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def get(key):
def run(self, task, job, **jobargs):
global Task
Task = task
worker.active_task = task
for key in self:
self[key] = self.getitem(key, job, jobargs)
assert self['version'] == '{0[0]}.{0[1]}'.format(sys.version_info[:2]), "Python version mismatch"
Expand Down
1 change: 1 addition & 0 deletions lib/disco/worker/pipeline/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def run(self, task, job, **jobargs):
# Entry point into the executing pipeline worker task. This
# initializes the task environment, sets up the current stage,
# and then executes it.
worker.active_task = task
for key in self:
self[key] = self.getitem(key, job, jobargs)
sys_version = '{0[0]}.{0[1]}'.format(sys.version_info[:2])
Expand Down

0 comments on commit c8fd71c

Please sign in to comment.