Skip to content

Commit

Permalink
Enable including Java worker for ray start command (#3838)
Browse files Browse the repository at this point in the history
  • Loading branch information
jovany-wang authored and raulchen committed Feb 4, 2019
1 parent 7ef830b commit e1c68a0
Show file tree
Hide file tree
Showing 8 changed files with 277 additions and 17 deletions.
34 changes: 34 additions & 0 deletions java/runtime/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,38 @@
<scope>test</scope>
</dependency>
</dependencies>

<build>
<plugins>
<!-- During the package phase, copy this module's dependencies into the
shared build/java directory two levels above this module. -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-dependency-plugin</artifactId>
<executions>
<execution>
<id>copy-dependencies-to-build</id>
<phase>package</phase>
<goals>
<goal>copy-dependencies</goal>
</goals>
<configuration>
<outputDirectory>${basedir}/../../build/java</outputDirectory>
<!-- Existing release and snapshot artifacts are not overwritten, except
when the source artifact is newer than the copy already present. -->
<overWriteReleases>false</overWriteReleases>
<overWriteSnapshots>false</overWriteSnapshots>
<overWriteIfNewer>true</overWriteIfNewer>
</configuration>
</execution>
</executions>
</plugin>

<!-- Write this module's own jar into the same build/java directory used
for the copied dependencies above. -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>2.3.1</version>
<configuration>
<outputDirectory>${basedir}/../../build/java</outputDirectory>
</configuration>
</plugin>
</plugins>
</build>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public void start() throws Exception {
}
redisClient = new RedisClient(rayConfig.getRedisAddress());

// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);

rayletClient = new RayletClientImpl(
Expand Down
1 change: 0 additions & 1 deletion java/test/src/main/java/org/ray/api/test/ActorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.annotation.RayRemote;
import org.ray.api.function.RayFunc2;
import org.ray.api.id.UniqueId;
import org.ray.runtime.RayActorImpl;
import org.testng.Assert;
Expand Down
115 changes: 115 additions & 0 deletions java/test/src/main/java/org/ray/api/test/MultiLanguageClusterTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package org.ray.api.test;

import com.google.common.collect.ImmutableList;
import java.io.File;
import java.io.IOException;
import java.lang.ProcessBuilder.Redirect;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.annotation.RayRemote;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.SkipException;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

/**
 * Test starting a ray cluster with multi-language support.
 *
 * <p>Each test method starts a head node through the {@code ray start} command line with
 * Java worker support enabled, connects to it from this JVM, and stops the cluster again
 * afterwards. The test is skipped when the {@code ray} command is not installed.
 */
public class MultiLanguageClusterTest {

  private static final Logger LOGGER = LoggerFactory.getLogger(MultiLanguageClusterTest.class);

  private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/test/plasma_store_socket";
  private static final String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket";

  /**
   * Remote function that returns its argument unchanged.
   */
  @RayRemote
  public static String echo(String word) {
    return word;
  }

  /**
   * Execute an external command, forwarding its stdout/stderr to this process.
   *
   * @param command The command and its arguments.
   * @param waitTimeoutSeconds Maximum time to wait for the command to finish.
   * @return Whether the command succeeded, i.e. finished within the timeout with exit code 0.
   * @throws RuntimeException If the command could not be started, or the wait was interrupted.
   */
  private boolean executeCommand(List<String> command, int waitTimeoutSeconds) {
    final String commandLine = String.join(" ", command);
    LOGGER.info("Executing command: {}", commandLine);
    Process process = null;
    try {
      process = new ProcessBuilder(command)
          .redirectOutput(Redirect.INHERIT)
          .redirectError(Redirect.INHERIT)
          .start();
      // waitFor returns false on timeout; calling exitValue() on a still-running
      // process would throw IllegalThreadStateException, so check the result first.
      if (!process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS)) {
        LOGGER.warn("Command didn't finish within {} seconds: {}",
            waitTimeoutSeconds, commandLine);
        return false;
      }
      return process.exitValue() == 0;
    } catch (InterruptedException e) {
      // Restore the interrupt flag so callers further up the stack can observe it.
      Thread.currentThread().interrupt();
      throw new RuntimeException("Error executing command " + commandLine, e);
    } catch (IOException e) {
      throw new RuntimeException("Error executing command " + commandLine, e);
    } finally {
      // Don't leave a timed-out or interrupted child process running.
      if (process != null && process.isAlive()) {
        process.destroyForcibly();
      }
    }
  }

  /**
   * Start a local ray cluster with Java worker support and connect this JVM to it.
   */
  @BeforeMethod
  public void setUp() {
    // Check whether 'ray' command is installed.
    boolean rayCommandExists = executeCommand(ImmutableList.of("which", "ray"), 5);
    if (!rayCommandExists) {
      throw new SkipException("Skipping test, because ray command doesn't exist.");
    }

    // Delete existing socket files.
    for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
      File file = new File(socket);
      if (file.exists() && !file.delete()) {
        // Best effort: keep going, but make the stale socket visible in the logs.
        LOGGER.warn("Couldn't delete existing socket file: {}", socket);
      }
    }

    // Start ray cluster.
    final List<String> startCommand = ImmutableList.of(
        "ray",
        "start",
        "--head",
        "--redis-port=6379",
        "--include-java",
        String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
        String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
        "--java-worker-options=-classpath ../../build/java/*:../../java/test/target/*"
    );
    if (!executeCommand(startCommand, 10)) {
      throw new RuntimeException("Couldn't start ray cluster.");
    }

    // Connect to the cluster.
    System.setProperty("ray.home", "../..");
    System.setProperty("ray.redis.address", "127.0.0.1:6379");
    System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
    System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
    Ray.init();
  }

  /**
   * Disconnect from the cluster, restore system properties, and stop the cluster.
   */
  @AfterMethod
  public void tearDown() {
    try {
      // Disconnect from the cluster.
      Ray.shutdown();
    } finally {
      // Always undo the global state from setUp, even if the shutdown failed, so that
      // neither the properties nor a running cluster leak into other tests.
      System.clearProperty("ray.home");
      System.clearProperty("ray.redis.address");
      System.clearProperty("ray.object-store.socket-name");
      System.clearProperty("ray.raylet.socket-name");

      // Stop ray cluster.
      final List<String> stopCommand = ImmutableList.of(
          "ray",
          "stop"
      );
      if (!executeCommand(stopCommand, 10)) {
        throw new RuntimeException("Couldn't stop ray cluster");
      }
    }
  }

  /**
   * Submit a remote call to the externally started cluster and check its result.
   */
  @Test
  public void testMultiLanguageCluster() {
    RayObject<String> obj = Ray.call(MultiLanguageClusterTest::echo, "hello");
    // TestNG's argument order is (actual, expected).
    Assert.assertEquals(obj.get(), "hello");
  }

}
10 changes: 9 additions & 1 deletion python/ray/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ def __init__(self, ray_params, head=False, shutdown_at_exit=True):

if head:
ray_params.update_if_absent(num_redis_shards=1, include_webui=True)
else:
redis_client = ray.services.create_redis_client(
ray_params.redis_address, ray_params.redis_password)
ray_params.include_java = (
ray.services.include_java_from_redis(redis_client))

self._ray_params = ray_params
self._config = (json.loads(ray_params._internal_config)
Expand Down Expand Up @@ -224,7 +229,10 @@ def start_raylet(self, use_valgrind=False, use_profiler=False):
use_profiler=use_profiler,
stdout_file=stdout_file,
stderr_file=stderr_file,
config=self._config)
config=self._config,
include_java=self._ray_params.include_java,
java_worker_options=self._ray_params.java_worker_options,
)
assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes
self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info]

Expand Down
17 changes: 14 additions & 3 deletions python/ray/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ class RayParams(object):
monitor the log files for all processes on this node and push their
contents to Redis.
autoscaling_config: path to autoscaling config file.
include_java (bool): If True, the raylet backend can also support
Java worker.
java_worker_options (str): The command options for Java worker.
_internal_config (str): JSON configuration for overriding
RayConfig defaults. For testing purposes ONLY.
"""
Expand Down Expand Up @@ -106,6 +109,8 @@ def __init__(self,
temp_dir=None,
include_log_monitor=None,
autoscaling_config=None,
include_java=False,
java_worker_options=None,
_internal_config=None):
self.object_id_seed = object_id_seed
self.redis_address = redis_address
Expand Down Expand Up @@ -136,6 +141,8 @@ def __init__(self,
self.temp_dir = temp_dir
self.include_log_monitor = include_log_monitor
self.autoscaling_config = autoscaling_config
self.include_java = include_java
self.java_worker_options = java_worker_options
self._internal_config = _internal_config
self._check_usage()

Expand All @@ -146,7 +153,7 @@ def update(self, **kwargs):
kwargs: The keyword arguments to set corresponding fields.
"""
for arg in kwargs:
if (hasattr(self, arg)):
if hasattr(self, arg):
setattr(self, arg, kwargs[arg])
else:
raise ValueError("Invalid RayParams parameter in"
Expand All @@ -161,7 +168,7 @@ def update_if_absent(self, **kwargs):
kwargs: The keyword arguments to set corresponding fields.
"""
for arg in kwargs:
if (hasattr(self, arg)):
if hasattr(self, arg):
if getattr(self, arg) is None:
setattr(self, arg, kwargs[arg])
else:
Expand All @@ -180,6 +187,10 @@ def _check_usage(self):
"num_gpus instead.")

if self.num_workers is not None:
raise Exception(
raise ValueError(
"The 'num_workers' argument is deprecated. Please use "
"'num_cpus' instead.")

if self.include_java is None and self.java_worker_options is not None:
raise ValueError("Should not specify `java-worker-options` "
"without providing `include-java`.")
26 changes: 22 additions & 4 deletions python/ray/scripts/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,17 @@ def cli(logging_level, logging_format):
"--temp-dir",
default=None,
help="manually specify the root temporary dir of the Ray process")
@click.option(
"--include-java",
is_flag=True,
default=None,
help="Enable Java worker support.")
@click.option(
"--java-worker-options",
required=False,
default=None,
type=str,
help="Overwrite the options to start Java workers.")
@click.option(
"--internal-config",
default=None,
Expand All @@ -212,8 +223,8 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
redis_max_memory, num_workers, num_cpus, num_gpus, resources, head,
no_ui, block, plasma_directory, huge_pages, autoscaling_config,
no_redirect_worker_output, no_redirect_output,
plasma_store_socket_name, raylet_socket_name, temp_dir,
internal_config):
plasma_store_socket_name, raylet_socket_name, temp_dir, include_java,
java_worker_options, internal_config):
# Convert hostnames to numerical IP address.
if node_ip_address is not None:
node_ip_address = services.address_to_ip(node_ip_address)
Expand Down Expand Up @@ -245,6 +256,8 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
plasma_store_socket_name=plasma_store_socket_name,
raylet_socket_name=raylet_socket_name,
temp_dir=temp_dir,
include_java=include_java,
java_worker_options=java_worker_options,
_internal_config=internal_config)

if head:
Expand Down Expand Up @@ -280,7 +293,9 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
num_redis_shards=num_redis_shards,
redis_max_clients=redis_max_clients,
include_webui=(not no_ui),
autoscaling_config=autoscaling_config)
autoscaling_config=autoscaling_config,
include_java=False,
)

node = ray.node.Node(ray_params, head=True, shutdown_at_exit=False)
redis_address = node.redis_address
Expand Down Expand Up @@ -322,6 +337,10 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
if no_ui:
raise Exception("If --head is not passed in, the --no-ui flag is "
"not relevant.")
if include_java is not None:
raise ValueError("--include-java should only be set for the head "
"node.")

redis_ip_address, redis_port = redis_address.split(":")

# Wait for the Redis server to be started. And throw an exception if we
Expand All @@ -348,7 +367,6 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
check_no_existing_redis_clients(ray_params.node_ip_address,
redis_client)
ray_params.update(redis_address=redis_address)

node = ray.node.Node(ray_params, head=False, shutdown_at_exit=False)
logger.info("\nStarted Ray on this node. If you wish to terminate the "
"processes that have been started, run\n\n"
Expand Down
Loading

0 comments on commit e1c68a0

Please sign in to comment.