From a67f6b2829e8f9649987c1bb990f0cfa1fbe738d Mon Sep 17 00:00:00 2001 From: Song Jiaming Date: Wed, 8 Sep 2021 11:09:02 +0000 Subject: [PATCH] FL server and client (#4537) --- ppml/pom.xml | 219 ++++++++++++++ ppml/ppml-conf.yaml | 8 + ppml/python/README.md | 1 + .../intel/analytics/zoo/ppml/FLClient.java | 276 ++++++++++++++++++ .../intel/analytics/zoo/ppml/FLHelper.java | 44 +++ .../intel/analytics/zoo/ppml/FLServer.java | 67 +++++ .../analytics/zoo/ppml/generated/FLProto.java | 0 .../zoo/ppml/generated/PSIServiceGrpc.java | 0 .../generated/ParameterServerServiceGrpc.java | 0 .../zoo/ppml/psi/PSIServiceImpl.java | 180 ++++++++++++ .../zoo/ppml/psi/PsiIntersection.java | 140 +++++++++ .../intel/analytics/zoo/ppml/psi/Utils.java | 58 ++++ .../zoo/ppml/psi/test/BenchmarkClient.java | 127 ++++++++ .../analytics/zoo/ppml/psi/test/Client.java | 88 ++++++ .../zoo/ppml/psi/test/NetworkCheckClient.java | 110 +++++++ .../zoo/ppml/psi/test/TestUtils.java | 267 +++++++++++++++++ ppml/src/main/proto/FLProto.proto | 219 ++++++++++++++ ppml/src/main/resources/psi/psi-conf.yaml | 8 + .../intel/analytics/zoo/ppml/VFLServer.scala | 20 ++ .../analytics/zoo/grpc/AbstractZooGrpc.java | 60 ++++ .../analytics/zoo/grpc/ZooGrpcClient.java | 68 +++++ .../analytics/zoo/grpc/ZooGrpcServer.java | 99 +++++-- .../analytics/zoo/utils/ConfigParser.java | 50 ++++ .../analytics/zoo/utils/ConfigParserTest.java | 68 +++++ .../intel/analytics/zoo/utils/TestHelper.java | 35 +++ 25 files changed, 2190 insertions(+), 22 deletions(-) create mode 100644 ppml/pom.xml create mode 100644 ppml/ppml-conf.yaml create mode 100644 ppml/python/README.md create mode 100644 ppml/src/main/java/com/intel/analytics/zoo/ppml/FLClient.java create mode 100644 ppml/src/main/java/com/intel/analytics/zoo/ppml/FLHelper.java create mode 100644 ppml/src/main/java/com/intel/analytics/zoo/ppml/FLServer.java rename {zoo => ppml}/src/main/java/com/intel/analytics/zoo/ppml/generated/FLProto.java (100%) rename {zoo => ppml}/src/main/java/com/intel/analytics/zoo/ppml/generated/PSIServiceGrpc.java (100%) rename {zoo => ppml}/src/main/java/com/intel/analytics/zoo/ppml/generated/ParameterServerServiceGrpc.java (100%) create mode 100644 ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/PSIServiceImpl.java create mode 100644 ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/PsiIntersection.java create mode 100644 ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/Utils.java create mode 100644 ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/test/BenchmarkClient.java create mode 100644 ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/test/Client.java create mode 100644 ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/test/NetworkCheckClient.java create mode 100644 ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/test/TestUtils.java create mode 100644 ppml/src/main/proto/FLProto.proto create mode 100644 ppml/src/main/resources/psi/psi-conf.yaml create mode 100644 ppml/src/main/scala/com/intel/analytics/zoo/ppml/VFLServer.scala create mode 100644 zoo/src/main/java/com/intel/analytics/zoo/grpc/AbstractZooGrpc.java create mode 100644 zoo/src/main/java/com/intel/analytics/zoo/grpc/ZooGrpcClient.java create mode 100644 zoo/src/main/java/com/intel/analytics/zoo/utils/ConfigParser.java create mode 100644 zoo/src/test/java/com/intel/analytics/zoo/utils/ConfigParserTest.java create mode 100644 zoo/src/test/java/com/intel/analytics/zoo/utils/TestHelper.java diff --git a/ppml/pom.xml b/ppml/pom.xml new file mode 100644 index 00000000000..0ad3bf96458 --- /dev/null +++ b/ppml/pom.xml @@ -0,0 +1,219 @@ + + 4.0.0 + com.intel.analytics.zoo + analytics-zoo-ppml + jar + 0.1.0-SNAPSHOT + analytics-zoo-ppml + https://github.com/analytics-zoo/narwhal/tree/master/Trusted_FL/src/PPML + + + + central + Maven Repository + https://repo1.maven.org/maven2 + + true + + + false + + + + + ossrh + ossrh repository + https://oss.sonatype.org/content/repositories/snapshots + + true + + + + + + 1.33.0 + 3.12.0 + 3.12.0 + 0.12.2 + 2.4 + 2.4.3 + UTF-8 + 2.12 + 2.12.8 + 2.1.0 + 3.0.7 + compile + + 1.8 + 1.8 + + + + + + + + io.grpc + grpc-bom + ${grpc.version} + pom + import + + + + + + + io.grpc + grpc-protobuf + + + io.grpc + grpc-stub + + + io.grpc + grpc-netty + + + commons-cli + commons-cli + 1.4 + + + org.scalatest + scalatest_${scala.major.version} + ${scalatest.version} + test + + + org.scala-lang + scala-compiler + ${scala.version} + ${spark-scope} + + + org.scala-lang + scala-reflect + ${scala.version} + ${spark-scope} + + + org.scala-lang + scala-library + ${scala.version} + ${scala-library-scope} + + + org.scala-lang + scala-actors + 2.11.8 + ${spark-scope} + + + org.apache.spark + spark-mllib_${scala.major.version} + ${spark.version} + ${spark-scope} + + + io.netty + netty + + + io.netty + netty-all + + + + + org.scala-lang + scalap + ${scala.version} + ${spark-scope} + + + junit + junit + 4.12 + test + + + com.intel.analytics.zoo + analytics-zoo-bigdl_0.13.0-spark_2.4.6 + 0.12.0-SNAPSHOT + + + + + + kr.motd.maven + os-maven-plugin + 1.6.2 + + + + + net.alchim31.maven + scala-maven-plugin + 3.4.2 + + + + compile + testCompile + + + + + + org.apache.maven.plugins + maven-assembly-plugin + + + jar-with-dependencies + + + + + package + + single + + + + + + org.scalatest + scalatest-maven-plugin + 1.0 + + ${project.build.directory}/surefire-reports + . + narwhal-test-report.txt + + true + + -Xmx6g -XX:MaxPermSize=1g + + + + test + + test + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + 8 + 8 + + + + + diff --git a/ppml/ppml-conf.yaml b/ppml/ppml-conf.yaml new file mode 100644 index 00000000000..444356b1f98 --- /dev/null +++ b/ppml/ppml-conf.yaml @@ -0,0 +1,8 @@ +servicesList: psi + +# Server property +# serverPort: + +# Client property +# clientTarget: +# taskID: \ No newline at end of file diff --git a/ppml/python/README.md b/ppml/python/README.md new file mode 100644 index 00000000000..b0c0670d72b --- /dev/null +++ b/ppml/python/README.md @@ -0,0 +1 @@ +# Analytics Zoo PPML Python API \ No newline at end of file diff --git a/ppml/src/main/java/com/intel/analytics/zoo/ppml/FLClient.java b/ppml/src/main/java/com/intel/analytics/zoo/ppml/FLClient.java new file mode 100644 index 00000000000..7d80b9823df --- /dev/null +++ b/ppml/src/main/java/com/intel/analytics/zoo/ppml/FLClient.java @@ -0,0 +1,276 @@ +/* + * Copyright 2021 The Analytic Zoo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.zoo.ppml; + +import com.intel.analytics.zoo.grpc.ZooGrpcClient; +import com.intel.analytics.zoo.ppml.generated.FLProto.*; +import com.intel.analytics.zoo.ppml.generated.PSIServiceGrpc; +import com.intel.analytics.zoo.ppml.generated.ParameterServerServiceGrpc; +import com.intel.analytics.zoo.ppml.psi.Utils; +import io.grpc.StatusRuntimeException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.*; +import java.util.concurrent.TimeUnit; + +public class FLClient extends ZooGrpcClient { + private static final Logger logger = LoggerFactory.getLogger(FLClient.class); + protected String taskID; + protected String clientID = UUID.randomUUID().toString(); + protected String salt; + protected int splitSize = 1000000; + private static PSIServiceGrpc.PSIServiceBlockingStub blockingStubPSI; + private static ParameterServerServiceGrpc.ParameterServerServiceBlockingStub blockingStubPS; + public FLClient(String[] args) { + super(args); + } + + @Override + protected void parseConfig() throws IOException { + FLHelper flHelper = getConfigFromYaml(FLHelper.class, configPath); + if (flHelper != null) { + serviceList = flHelper.servicesList; + target = flHelper.clientTarget; + taskID = flHelper.taskID; + } + super.parseConfig(); + } + + @Override + public void loadServices() { + for (String service : serviceList.split(",")) { + if (service.equals("psi")) { + blockingStubPSI = PSIServiceGrpc.newBlockingStub(channel); + } else if (service.equals("ps")) { + // TODO: algorithms stub add here + } else { + logger.warn("Type is not supported, skipped. Type: " + service); + } + } + + } + public String getSalt() { + if (this.taskID.isEmpty()) { + this.taskID = Utils.getRandomUUID(); + } + return getSalt(this.taskID, 2, "Test"); + } + + /** + * For PSI usage only + * To get salt from FL Server, will get a new one if its salt does not exist on server + * @param name String, taskID + * @param clientNum int, client number + * @param secureCode String, secure code + * @return String, the salt get from server + */ + public String getSalt(String name, int clientNum, String secureCode) { + logger.info("Processing task with taskID: " + name + " ..."); + SaltRequest request = SaltRequest.newBuilder() + .setTaskId(name) + .setClientNum(clientNum) + .setSecureCode(secureCode).build(); + SaltReply response; + try { + response = blockingStubPSI.getSalt(request); + } catch (StatusRuntimeException e) { + throw new RuntimeException("RPC failed: " + e.getMessage()); + } + if (!response.getSaltReply().isEmpty()) { + salt = response.getSaltReply(); + } + return response.getSaltReply(); + } + + /** + * For PSI usage only + * Upload local set to FL Server in VFL + * @param hashedIdArray List of String, the set trained at local + */ + public void uploadSet(List hashedIdArray) { + int numSplit = Utils.getTotalSplitNum(hashedIdArray, splitSize); + int split = 0; + while (split < numSplit) { + List splitArray = Utils.getSplit(hashedIdArray, split, numSplit, splitSize); + UploadSetRequest request = UploadSetRequest.newBuilder() + .setTaskId(taskID) + .setSplit(split) + .setNumSplit(numSplit) + .setSplitLength(splitSize) + .setTotalLength(hashedIdArray.size()) + .setClientId(clientID) + .addAllHashedID(splitArray) + .build(); + try { + blockingStubPSI.uploadSet(request); + } catch (StatusRuntimeException e) { + throw new RuntimeException("RPC failed: " + e.getMessage()); + } + split ++; + } + } + + /** + * For PSI usage only + * Download intersection from FL Server in VFL + * @return List of String, the intersection downloaded + */ + public List downloadIntersection() { + List result = new ArrayList(); + try { + logger.info("Downloading 0th intersection"); + DownloadIntersectionRequest request = DownloadIntersectionRequest.newBuilder() + .setTaskId(taskID) + .setSplit(0) + .build(); + DownloadIntersectionResponse response = blockingStubPSI.downloadIntersection(request); + logger.info("Downloaded 0th intersection"); + result.addAll(response.getIntersectionList()); + for (int i = 1; i < response.getNumSplit(); i++) { + request = DownloadIntersectionRequest.newBuilder() + .setTaskId(taskID) + .setSplit(i) + .build(); + logger.info("Downloading " + i + "th intersection"); + response = blockingStubPSI.downloadIntersection(request); + logger.info("Downloaded " + i + "th intersection"); + result.addAll(response.getIntersectionList()); + } + assert(result.size() == response.getTotalLength()); + } catch (StatusRuntimeException e) { + throw new RuntimeException("RPC failed: " + e.getMessage()); + } + return result; + } + + public void shutdown() { + try { + channel.shutdown().awaitTermination(5, TimeUnit.SECONDS); + } catch (InterruptedException e) { + logger.error("Shutdown Client Error" + e.getMessage()); + } + } + + + public DownloadResponse downloadTrain(String modelName, int flVersion) { + logger.info("Download the following data:"); + TableMetaData metadata = TableMetaData.newBuilder() + .setName(modelName).setVersion(flVersion + 1).build(); + DownloadRequest downloadRequest = DownloadRequest.newBuilder().setMetaData(metadata).build(); + return blockingStubPS.downloadTrain(downloadRequest); + } + + public UploadResponse uploadTrain(Table data) { + + UploadRequest uploadRequest = UploadRequest + .newBuilder() + .setData(data) + .setClientuuid(clientUUID) + .build(); + + logger.info("Upload the following data:"); + logger.info("Upload Data Name:" + data.getMetaData().getName()); + logger.info("Upload Data Version:" + data.getMetaData().getVersion()); + logger.debug("Upload Data" + data.getTableMap()); +// logger.info("Upload" + data.getTableMap().get("weights").getTensorList().subList(0, 5)); + + UploadResponse uploadResponse = blockingStubPS.uploadTrain(uploadRequest); + return uploadResponse; + } + + public EvaluateResponse evaluate(Table data, boolean lastBatch) { + EvaluateRequest eRequest = EvaluateRequest + .newBuilder() + .setData(data) + .setClientuuid(clientUUID) + .setLast(lastBatch) + .build(); + + return blockingStubPS.uploadEvaluate(eRequest); + } + + public UploadResponse uploadSplit(DataSplit ds) { + UploadSplitRequest uploadRequest = UploadSplitRequest + .newBuilder() + .setSplit(ds) + .setClientuuid(clientUUID) + .build(); + + return blockingStubPS.uploadSplitTrain(uploadRequest); + } + + /*** + * XGBoost download aggregated best split + * @param treeID + * @return + */ + public DownloadSplitResponse downloadSplit( + String treeID, + String nodeID) { + DownloadSplitRequest downloadRequest = DownloadSplitRequest + .newBuilder() + .setTreeID(treeID) + .setNodeID(nodeID) + .setClientuuid(clientUUID) + .build(); + return blockingStubPS.downloadSplitTrain(downloadRequest); + } + + public UploadResponse uploadTreeEval( + List boostEval) { + UploadTreeEvalRequest uploadTreeEvalRequest = UploadTreeEvalRequest + .newBuilder() + .setClientuuid(clientUUID) + .addAllTreeEval(boostEval) + .build(); + + return blockingStubPS.uploadTreeEval(uploadTreeEvalRequest); + } + + public PredictTreeResponse uploadTreePred( + List boostEval) { + PredictTreeRequest request = PredictTreeRequest + .newBuilder() + .setClientuuid(clientUUID) + .addAllTreeEval(boostEval) + .build(); + + return blockingStubPS.predictTree(request); + } + + + public UploadResponse uploadTreeLeaves( + String treeID, + List treeIndexes, + List treeOutput + ) { + TreeLeaves treeLeaves = TreeLeaves + .newBuilder() + .setTreeID(treeID) + .addAllLeafIndex(treeIndexes) + .addAllLeafOutput(treeOutput) + .build(); + UploadTreeLeavesRequest uploadTreeLeavesRequest = UploadTreeLeavesRequest + .newBuilder() + .setClientuuid(clientUUID) + .setTreeLeaves(treeLeaves) + .build(); + return blockingStubPS.uploadTreeLeaves(uploadTreeLeavesRequest); + } +} diff --git a/ppml/src/main/java/com/intel/analytics/zoo/ppml/FLHelper.java b/ppml/src/main/java/com/intel/analytics/zoo/ppml/FLHelper.java new file mode 100644 index 00000000000..e93617c4f60 --- /dev/null +++ b/ppml/src/main/java/com/intel/analytics/zoo/ppml/FLHelper.java @@ -0,0 +1,44 @@ +/* + * Copyright 2021 The Analytic Zoo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.zoo.ppml; + +public class FLHelper { + String servicesList; + + // Server property + int serverPort = 8980; + + // Client property + String clientTarget = "localhost:8980"; + String taskID = "taskID"; + + public void setServicesList(String servicesList) { + this.servicesList = servicesList; + } + + public void setServerPort(int serverPort) { + this.serverPort = serverPort; + } + + public void setClientTarget(String clientTarget) { + this.clientTarget = clientTarget; + } + + public void setTaskID(String taskID) { + this.taskID = taskID; + } +} diff --git a/ppml/src/main/java/com/intel/analytics/zoo/ppml/FLServer.java b/ppml/src/main/java/com/intel/analytics/zoo/ppml/FLServer.java new file mode 100644 index 00000000000..c9494fcdb29 --- /dev/null +++ b/ppml/src/main/java/com/intel/analytics/zoo/ppml/FLServer.java @@ -0,0 +1,67 @@ +/* + * Copyright 2021 The Analytic Zoo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.zoo.ppml; + +import com.intel.analytics.zoo.grpc.ZooGrpcServer; +import com.intel.analytics.zoo.ppml.psi.PSIServiceImpl; +import io.grpc.BindableService; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +/** + * FLServer is Analytics Zoo PPML gRPC server used for FL based on ZooGrpcServer + * User could also call main method and parse server type to start gRPC service + * Supported types: PSI + */ +public class FLServer extends ZooGrpcServer { + private static final Logger logger = LoggerFactory.getLogger(FLServer.class); + + FLServer(String[] args, BindableService service) { + super(args, service); + configPath = "ppml-conf.yaml"; + } + FLServer(String[] args) { + this(args, null); + } + + @Override + public void parseConfig() throws IOException { + FLHelper flHelper = getConfigFromYaml(FLHelper.class, configPath); + if (flHelper != null) { + serviceList = flHelper.servicesList; + port = flHelper.serverPort; + } + for (String service : serviceList.split(",")) { + if (service.equals("psi")) { + serverServices.add(new PSIServiceImpl()); + } else if (service.equals("ps")) { + // TODO: add algorithms here + } else { + logger.warn("Type is not supported, skipped. Type: " + service); + } + } + } + + public static void main(String[] args) throws IOException, InterruptedException { + FLServer flServer = new FLServer(args); + flServer.build(); + flServer.start(); + flServer.blockUntilShutdown(); + } +} diff --git a/zoo/src/main/java/com/intel/analytics/zoo/ppml/generated/FLProto.java b/ppml/src/main/java/com/intel/analytics/zoo/ppml/generated/FLProto.java similarity index 100% rename from zoo/src/main/java/com/intel/analytics/zoo/ppml/generated/FLProto.java rename to ppml/src/main/java/com/intel/analytics/zoo/ppml/generated/FLProto.java diff --git a/zoo/src/main/java/com/intel/analytics/zoo/ppml/generated/PSIServiceGrpc.java b/ppml/src/main/java/com/intel/analytics/zoo/ppml/generated/PSIServiceGrpc.java similarity index 100% rename from zoo/src/main/java/com/intel/analytics/zoo/ppml/generated/PSIServiceGrpc.java rename to ppml/src/main/java/com/intel/analytics/zoo/ppml/generated/PSIServiceGrpc.java diff --git a/zoo/src/main/java/com/intel/analytics/zoo/ppml/generated/ParameterServerServiceGrpc.java b/ppml/src/main/java/com/intel/analytics/zoo/ppml/generated/ParameterServerServiceGrpc.java similarity index 100% rename from zoo/src/main/java/com/intel/analytics/zoo/ppml/generated/ParameterServerServiceGrpc.java rename to ppml/src/main/java/com/intel/analytics/zoo/ppml/generated/ParameterServerServiceGrpc.java diff --git a/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/PSIServiceImpl.java b/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/PSIServiceImpl.java new file mode 100644 index 00000000000..2acfb008b9a --- /dev/null +++ b/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/PSIServiceImpl.java @@ -0,0 +1,180 @@ +/* + * Copyright 2021 The Analytics Zoo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.intel.analytics.zoo.ppml.psi; + +import com.intel.analytics.zoo.ppml.generated.FLProto.*; +import com.intel.analytics.zoo.ppml.generated.PSIServiceGrpc; +import io.grpc.stub.StreamObserver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +public class PSIServiceImpl extends PSIServiceGrpc.PSIServiceImplBase { + private static final Logger logger = LoggerFactory.getLogger(PSIServiceImpl.class); + // TODO Thread safe + protected Map psiTasks = new HashMap<>(); + // This psiCollections is + // TaskId, ClientId, UploadRequest + protected Map> psiCollections = new HashMap<>(); + HashMap clientNum = new HashMap<>(); + HashMap clientSalt = new HashMap<>(); + HashMap clientSecret = new HashMap<>(); + // Stores the seed used in shuffling for each taskId + HashMap clientShuffleSeed = new HashMap<>(); + protected int splitSize = 1000000; + + @Override + public void getSalt(SaltRequest req, StreamObserver responseObserver) { + String salt; + // Store salt + String taskId = req.getTaskId(); + if (clientSalt.containsKey(taskId)) { + salt = clientSalt.get(taskId); + } else { + salt = Utils.getRandomUUID(); + clientSalt.put(taskId, salt); + } + // Store clientNum + if (req.getClientNum() != 0 && !clientNum.containsKey(taskId)) { + clientNum.put(taskId, req.getClientNum()); + } + // Store secure + if (!clientSecret.containsKey(taskId)) { + clientSecret.put(taskId, req.getSecureCode()); + } else if (!clientSecret.get(taskId).equals(req.getSecureCode())) { + // TODO Reply empty + } + // Store random seed for shuffling + if (!clientShuffleSeed.containsKey(taskId)) { + clientShuffleSeed.put(taskId, Utils.getRandomInt()); + } + SaltReply reply = SaltReply.newBuilder().setSaltReply(salt).build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + } + + @Override + public void uploadSet(UploadSetRequest request, + StreamObserver responseObserver) { + String taskId = request.getTaskId(); + SIGNAL signal; + if (!clientNum.containsKey(taskId)) { + signal= SIGNAL.ERROR; + logger.error("TaskId not found in server, please get salt first. " + + "TaskID:" + taskId + ", ClientID:" + request.getClientId()); + } else { + signal= SIGNAL.SUCCESS; + String clientId = request.getClientId(); + int numSplit = request.getNumSplit(); + int splitLength = request.getSplitLength(); + int totalLength = request.getTotalLength(); + if(!psiCollections.containsKey(taskId)){ + psiCollections.put(taskId, new HashMap()); + } + if(!psiCollections.get(taskId).containsKey(clientId)){ + if(psiCollections.get(taskId).size() >= clientNum.get(taskId)) { + logger.error("Too many clients, already has " + + psiCollections.get(taskId).keySet() + + ". The new one is " + clientId); + } + psiCollections.get(taskId).put(clientId, new String[totalLength]); + } + String[] collectionStorage = psiCollections.get(taskId).get(clientId); + String[] ids = request.getHashedIDList().toArray(new String[request.getHashedIDList().size()]); + int split = request.getSplit(); + // TODO: verify requests' splits are unique. + System.arraycopy(ids, 0, collectionStorage, split * splitLength, ids.length); + logger.info("ClientId" + clientId + ",split: " + split + ", numSplit: " + numSplit + "."); + if (split == numSplit - 1) { + synchronized (psiTasks) { + try { + if (psiTasks.containsKey(taskId)) { + logger.info("Adding " + (psiTasks.get(taskId).numCollection() + 1) + + "th collections to " + taskId + "."); + long st = System.currentTimeMillis(); + psiTasks.get(taskId).addCollection(collectionStorage); + logger.info("Added " + (psiTasks.get(taskId).numCollection()) + + "th collections to " + taskId + ". Find Intersection time cost: " + (System.currentTimeMillis()-st) + " ms"); + } else { + logger.info("Adding 1th collections."); + PsiIntersection pi = new PsiIntersection(clientNum.get(taskId), + clientShuffleSeed.get(taskId)); + pi.addCollection(collectionStorage); + psiTasks.put(taskId, pi); + logger.info("Added 1th collections."); + } + psiCollections.get(taskId).remove(clientId); + } catch (InterruptedException | ExecutionException e){ + logger.error(e.getMessage()); + signal= SIGNAL.ERROR; + } catch (IllegalArgumentException iae) { + logger.error("TaskId " + taskId + ": Too many collections from client."); + logger.error("Current client ids are " + psiCollections.get(taskId).keySet()); + logger.error(iae.getMessage()); + throw iae; + } + } + } + } + + UploadSetResponse response = UploadSetResponse.newBuilder() + .setTaskId(taskId) + .setStatus(signal) + .build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + } + + @Override + public void downloadIntersection(DownloadIntersectionRequest request, + StreamObserver responseObserver) { + String taskId = request.getTaskId(); + SIGNAL signal = SIGNAL.SUCCESS; + if (psiTasks.containsKey(taskId)) { + try { + List intersection = psiTasks.get(taskId).getIntersection(); + int split = request.getSplit(); + int numSplit = Utils.getTotalSplitNum(intersection, splitSize); + List splitIntersection = Utils.getSplit(intersection, split, numSplit, splitSize); + DownloadIntersectionResponse response = DownloadIntersectionResponse.newBuilder() + .setTaskId(taskId) + .setStatus(signal) + .setSplit(split) + .setNumSplit(numSplit) + .setTotalLength(intersection.size()) + .setSplitLength(splitSize) + .addAllIntersection(splitIntersection).build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + } catch (InterruptedException e) { + logger.error(e.getMessage()); + signal = SIGNAL.ERROR; + DownloadIntersectionResponse response = DownloadIntersectionResponse.newBuilder() + .setTaskId(taskId) + .setStatus(signal).build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + } + } + } + + +} + diff --git a/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/PsiIntersection.java b/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/PsiIntersection.java new file mode 100644 index 00000000000..9ce7048d1d4 --- /dev/null +++ b/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/PsiIntersection.java @@ -0,0 +1,140 @@ +/* + * Copyright 2021 The Analytic Zoo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.zoo.ppml.psi; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.*; + +public class PsiIntersection { + public final int maxCollection; + public final int shuffleSeed; + + protected final int nThreads = Integer.parseInt(System.getProperty( + "PsiThreads", "6")); + + protected ExecutorService pool = Executors.newFixedThreadPool(nThreads); + + public PsiIntersection(int maxCollection, int shuffleSeed) { + this.maxCollection = maxCollection; + this.shuffleSeed = shuffleSeed; + } + + protected List collections = new ArrayList(); + protected List intersection; + + public int numCollection() { + return collections.size(); + } + + public void addCollection( + String[] collection) throws InterruptedException, ExecutionException{ + synchronized (this) { + if (collections.size() == maxCollection) { + throw new IllegalArgumentException("Collection is full."); + } + collections.add(collection); + if (collections.size() >= maxCollection) { + // TODO: sort by collections' size + String[] current = collections.get(0); + for(int i = 1; i < maxCollection - 1; i++){ + Arrays.parallelSort(current); + current = findIntersection(current, collections.get(i)) + .toArray(new String[intersection.size()]); + } + Arrays.parallelSort(current); + List result = findIntersection(current, collections.get(maxCollection - 1)); + Utils.shuffle(result, shuffleSeed); + intersection = result; + this.notifyAll(); + } + } + } + + // Join a with b, a should be sorted. + private static class FindIntersection implements Callable> { + protected String[] a; + protected String[] b; + protected int bStart; + protected int length; + + public FindIntersection(String[] a, + String[] b, + int bStart, + int length) { + this.a = a; + this.b = b; + this.bStart = bStart; + this.length = length; + } + + @Override + public List call() { + return findIntersection(a, b, bStart, length); + } + + protected static List findIntersection( + String[] a, + String[] b, + int start, + int length){ + ArrayList intersection = new ArrayList(); + for(int i = start; i < length + start; i++) { + if (Arrays.binarySearch(a, b[i]) >= 0){ + intersection.add(b[i]); + } + } + return intersection; + } + } + + protected List findIntersection( + String[] a, + String[] b) throws InterruptedException, ExecutionException{ + int[] splitPoints = new int[nThreads + 1]; + int extractLen = b.length - nThreads * (b.length / nThreads); + for(int i = 1; i < splitPoints.length; i++) { + splitPoints[i] = b.length / nThreads * i; + if (i <= extractLen) { + splitPoints[i] += i; + } else { + splitPoints[i] += extractLen; + } + } + + Future>[] futures = new Future[nThreads]; + for(int i = 0; i < nThreads; i++) { + futures[i] = pool.submit(new FindIntersection(a, b, splitPoints[i], + splitPoints[i + 1] - splitPoints[i])); + } + List intersection = futures[0].get(); + for(int i = 1; i < nThreads; i++) { + intersection.addAll(futures[i].get()); + } + return intersection; + } + + public List getIntersection() throws InterruptedException{ + synchronized (this) { + if(null == intersection) { + this.wait(); + } + return intersection; + } + } +} diff --git a/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/Utils.java b/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/Utils.java new file mode 100644 index 00000000000..5ec0aec019e --- /dev/null +++ b/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/Utils.java @@ -0,0 +1,58 @@ +/* + * Copyright 2021 The Analytics Zoo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.zoo.ppml.psi; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.*; + + +public class Utils { + private static final Logger logger = LoggerFactory.getLogger(Utils.class); + + // TODO for XGboost + public ArrayList> partition(int partitionNum, ArrayList dataset) { + return new ArrayList<>(); + } + + public static void shuffle(List array,int seed){ + Collections.shuffle(array,new Random(seed)); + } + public static String getRandomUUID() { + return UUID.randomUUID().toString(); + } + + public static int getRandomInt() { + Random rand = new Random(); + return rand.nextInt(); + } + + + public static int getTotalSplitNum(List list, int splitSize) { + return (int)Math.ceil((double)list.size() / splitSize); + } + + public static List getSplit(List list, int split, int totalSplitNum, int splitSize) { + if (split < totalSplitNum - 1) { + return list.subList(split * splitSize, (split + 1) * splitSize); + } else { + return list.subList(split * splitSize, list.size()); + } + + } +} diff --git a/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/test/BenchmarkClient.java b/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/test/BenchmarkClient.java new file mode 100644 index 00000000000..46066348859 --- /dev/null +++ b/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/test/BenchmarkClient.java @@ -0,0 +1,127 @@ +/* + * Copyright 2021 The Analytic Zoo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.zoo.ppml.psi.test; + +import com.intel.analytics.zoo.ppml.FLClient; +import com.intel.analytics.zoo.ppml.generated.PSIServiceGrpc; +import io.grpc.Channel; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.concurrent.TimeUnit; + +public class BenchmarkClient { + private static final Logger logger = LoggerFactory.getLogger(BenchmarkClient.class); + + public static void main(String[] args) throws Exception { + String taskID; + String target; + int idSize; + int startNum; + // Number of arguments to be passed. + int argNum = 5; + if (args.length == 0) { + logger.info("No argument passed, using default parameters."); + taskID = "taskID"; + target = "localhost:50051"; + idSize = 10000; + startNum = 0; + } else if (args.length < argNum || args.length > argNum + 1) { + logger.info("Error: detecting " + Integer.toString(args.length) + " arguments. Expecting " + Integer.toString(argNum) + "."); + logger.info("Usage: BenchmarkClient taskID ServerIP ServerPort"); + taskID = ""; + target = ""; + idSize = 0; + startNum = 0; + System.exit(0); + } else { + taskID = args[0]; + target = args[1] + ":" + args[2]; + idSize = Integer.parseInt(args[3]); + startNum = Integer.parseInt(args[4]); + } + logger.info("TaskID is: " + taskID); + logger.info("Accessing service at: " + target); + + // Example code for client + // Quick lookup for the plaintext of hashed ids + List ids = new ArrayList(idSize); + long stproduce = System.currentTimeMillis(); + for (int i = startNum; i < idSize; i++) { + ids.add(i-startNum, String.valueOf(i)); + } + long etproduce = System.currentTimeMillis(); + logger.info("### Time of producing data: " + (etproduce - stproduce) + " ms ###"); + HashMap hashedIds = new HashMap<>(); + List hashedIdArray; + String salt; + + // Create a communication channel to the server, known as a Channel. Channels are thread-safe + // and reusable. It is common to create channels at the beginning of your application and reuse + // them until the application shuts down. + ManagedChannel channel = ManagedChannelBuilder.forTarget(target) + // Channels are secure by default (via SSL/TLS). + //extend message size of server to 200M to avoid size conflict + .maxInboundMessageSize(Integer.MAX_VALUE) + .usePlaintext() + .build(); + try { + String[] arg = {"-c", BenchmarkClient.class.getClassLoader() + .getResource("psi/psi-conf.yaml").getPath()}; + FLClient client = new FLClient(arg); + client.build(); + // Get salt from Server + salt = client.getSalt(); + logger.info("Client get Slat=" + salt); + // Hash(IDs, salt) into hashed IDs + long shash = System.currentTimeMillis(); + hashedIdArray = TestUtils.parallelToSHAHexString(ids, salt); + for (int i = 0; i < ids.size(); i++) { + logger.debug(hashedIdArray.get(i)); + hashedIds.put(hashedIdArray.get(i), ids.get(i)); + } + long ehash = System.currentTimeMillis(); + logger.info("### Time of hash data: " + (ehash - shash) + " ms ###"); + logger.info("HashedIDs Size = " + hashedIdArray.size()); + long supload = System.currentTimeMillis(); + client.uploadSet(hashedIdArray); + long eupload = System.currentTimeMillis(); + logger.info("### Time of upload data: " + (eupload - supload) + " ms ###"); + logger.info("upload hashed id successfully"); + List intersection; + + long sdownload = System.currentTimeMillis(); + intersection = client.downloadIntersection(); + long edownload = System.currentTimeMillis(); + logger.info("### Time of download data: " + (edownload - sdownload) + " ms ###"); + logger.info("Intersection successful. Total id(s) in intersection is " + intersection.size()); + + } finally { + // ManagedChannels use resources like threads and TCP connections. To prevent leaking these + // resources the channel should be shut down when it will no longer be used. If it may be used + // again leave it running. + channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); + } + } +} + + diff --git a/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/test/Client.java b/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/test/Client.java new file mode 100644 index 00000000000..84668088d53 --- /dev/null +++ b/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/test/Client.java @@ -0,0 +1,88 @@ +/* + * Copyright 2021 The Analytic Zoo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.zoo.ppml.psi.test; + +import com.intel.analytics.zoo.ppml.FLClient; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.concurrent.TimeUnit; + + +public class Client { + private static final Logger logger = LoggerFactory.getLogger(Client.class); + + public static void main(String[] args) throws Exception { + + int max_wait = 20; + // Example code for client + int idSize = 11; + // Quick lookup for the plaintext of hashed ids + HashMap data = TestUtils.genRandomHashSet(idSize); + HashMap hashedIds = new HashMap<>(); + List hashedIdArray; + String salt; + List ids = new ArrayList<>(data.keySet()); + + // Create a communication channel to the server, known as a Channel. Channels are thread-safe + // and reusable. It is common to create channels at the beginning of your application and reuse + // them until the application shuts down. + String[] arg = {"-c", BenchmarkClient.class.getClassLoader() + .getResource("psi/psi-conf.yaml").getPath()}; + FLClient client = new FLClient(arg); + try { + client.build(); + // Get salt from Server + salt = client.getSalt(); + logger.debug("Client get Slat=" + salt); + // Hash(IDs, salt) into hashed IDs + hashedIdArray = TestUtils.parallelToSHAHexString(ids, salt); + for (int i = 0; i < ids.size(); i++) { + hashedIds.put(hashedIdArray.get(i), ids.get(i)); + } + logger.debug("HashedIDs Size = " + hashedIds.size()); + client.uploadSet(hashedIdArray); + List intersection; + + while (max_wait > 0) { + intersection = client.downloadIntersection(); + if (intersection == null) { + logger.info("Wait 1000ms"); + Thread.sleep(1000); + } else { + logger.info("Intersection successful. Intersection's size is " + intersection.size() + "."); + break; + } + max_wait--; + } + } catch (Exception e) { + e.printStackTrace(); + } finally{ + // ManagedChannels use resources like threads and TCP connections. To prevent leaking these + // resources the channel should be shut down when it will no longer be used. If it may be used + // again leave it running. + client.getChannel().shutdownNow().awaitTermination(5, TimeUnit.SECONDS); + } + } +} + + diff --git a/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/test/NetworkCheckClient.java b/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/test/NetworkCheckClient.java new file mode 100644 index 00000000000..e9f7846a687 --- /dev/null +++ b/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/test/NetworkCheckClient.java @@ -0,0 +1,110 @@ +/* + * Copyright 2021 The Analytic Zoo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.zoo.ppml.psi.test; + +import com.intel.analytics.zoo.ppml.FLClient; +import com.intel.analytics.zoo.ppml.generated.PSIServiceGrpc; +import io.grpc.Channel; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.concurrent.TimeUnit; + +public class NetworkCheckClient{ + public static void main(String[] args) throws Exception { + String taskID; + String target; + // Number of arguments to be passed. + int argNum = 3; + if (args.length == 0) { + //logger.info("No argument passed, using default parameters."); + taskID = "taskID"; + target = "localhost:50051"; + } else if (args.length < argNum || args.length > argNum + 1) { + //logger.info("Error: detecting " + Integer.toString(args.length) + " arguments. Expecting " + Integer.toString(argNum) + "."); + //logger.info("Usage: BenchmarkClient taskID ServerIP ServerPort"); + taskID = ""; + target = ""; + System.exit(0); + } else { + taskID = args[0]; + target = args[1] + ":" + args[2]; + } + //logger.info("TaskID is: " + taskID); + //logger.info("Accessing service at: " + target); + + int max_wait = 2000; + // Example code for client + int idSize = 150000; + // Quick lookup for the plaintext of hashed ids + HashMap data = TestUtils.getRandomHashSetOfStringForFiveFixed(idSize);//Utils.genRandomHashSet(idSize); + HashMap hashedIds = new HashMap<>(); + List hashedIdArray; + String salt; + List ids = new ArrayList<>(data.keySet()); + + // Create a communication channel to the server, known as a Channel. Channels are thread-safe + // and reusable. It is common to create channels at the beginning of your application and reuse + // them until the application shuts down. + ManagedChannel channel = ManagedChannelBuilder.forTarget(target) + // Channels are secure by default (via SSL/TLS). + //extend message size of server to 200M to avoid size conflict + .maxInboundMessageSize(209715200) + .usePlaintext() + .build(); + try { + String[] arg = {"-c", BenchmarkClient.class.getClassLoader() + .getResource("psi/psi-conf.yaml").getPath()}; + FLClient client = new FLClient(arg); + client.build(); + // Get salt from Server + salt = client.getSalt(); + //logger.debug("Client get Slat=" + salt); + // Hash(IDs, salt) into hashed IDs + hashedIdArray = TestUtils.parallelToSHAHexString(ids, salt); + for (int i = 0; i < ids.size(); i++) { + hashedIds.put(hashedIdArray.get(i), ids.get(i)); + } + //logger.debug("HashedIDs Size = " + hashedIds.size()); + client.uploadSet(hashedIdArray); + List intersection; + + while (max_wait > 0) { + intersection = client.downloadIntersection(); + if (intersection == null) { + //logger.info("Wait 1000ms"); + Thread.sleep(1000); + } else { + System.out.println("Intersection successful. Intersection's size is " + intersection.size() + "."); + break; + } + max_wait--; + } + + } finally { + // ManagedChannels use resources like threads and TCP connections. To prevent leaking these + // resources the channel should be shut down when it will no longer be used. If it may be used + // again leave it running. + channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); + } + } +} + + diff --git a/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/test/TestUtils.java b/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/test/TestUtils.java new file mode 100644 index 00000000000..9089078cec8 --- /dev/null +++ b/ppml/src/main/java/com/intel/analytics/zoo/ppml/psi/test/TestUtils.java @@ -0,0 +1,267 @@ +/* + * Copyright 2021 The Analytic Zoo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.zoo.ppml.psi.test; + + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.*; +import java.util.concurrent.*; + +public class TestUtils { + private static final Logger logger = LoggerFactory.getLogger(TestUtils.class); + /*** + * Gen random HashMap for test + * @param size HashMap size, int + * @return + */ + public static HashMap genRandomHashSet(int size) { + HashMap data = new HashMap<>(); + Random rand = new Random(); + for (int i = 0; i < size; i++) { + String name = "User_" + rand.nextInt(); + data.put(name, Integer.toString(i)); + } + logger.info("IDs are: "); + for (Map.Entry element : data.entrySet()) { + logger.info(element.getKey() + ", " + element.getValue()); + } + return data; + } + + protected static final int nThreads = Integer.parseInt(System.getProperty( + "PsiThreads", "6")); + + + public static List parallelToSHAHexString( + List ids, + String salt) throws InterruptedException, ExecutionException { + return parallelToSHAHexString(ids, salt, 256, 32); + } + + public static List parallelToSHAHexString( + List ids, + String salt, + int length, + int paddingSize) throws InterruptedException, ExecutionException { + String[] idsArray = ids.toArray(new String[ids.size()]); + String[] hashedIds = parallelToSHAHexString(idsArray, salt, length, paddingSize); +// return new ArrayList(Arrays.asList(hashedIds)); + return Arrays.asList(hashedIds); + } + + public static String[] parallelToSHAHexString( + String[] ids, + String salt) throws InterruptedException, ExecutionException { + return parallelToSHAHexString(ids, salt, 256, 32); + } + + public static String[] parallelToSHAHexString( + String[] ids, + String salt, + int length, + int paddingSize) throws InterruptedException, ExecutionException { + String[] output = new String[ids.length]; + ExecutorService pool = Executors.newFixedThreadPool(nThreads); + int extractLen = ids.length - nThreads * (ids.length / nThreads); + int average = ids.length / nThreads; + Future[] futures = new Future[nThreads]; + for(int i = 0; i < nThreads - 1; i++) { + futures[i] = pool.submit(new StringToSHAHex(ids, average * i, + average, output, salt, length, paddingSize)); + } + futures[nThreads - 1] = pool.submit(new StringToSHAHex(ids, average * (nThreads - 1), + average + extractLen, output, salt, length, paddingSize)); + + for(int i = 0; i < nThreads; i++) { + futures[i].get(); + } + pool.shutdown(); + return output; + } + + private static class StringToSHAHex implements Callable { + protected String[] src; + protected int start; + protected int length; + protected String[] dest; + protected int paddingSize; + protected String salt; + protected MessageDigest generator; + + public StringToSHAHex( + String[] src, + int start, + int length, + String[] dest) { + this(src, start, length, dest, "", 256, 32); + } + + public StringToSHAHex( + String[] src, + int start, + int length, + String[] dest, + String salt, + int shaLength, + int paddingSize) { + this.src = src; + this.start = start; + this.length = length; + this.dest = dest; + this.paddingSize = paddingSize; + this.salt = salt; + try { + this.generator = MessageDigest.getInstance("SHA-" + shaLength); + } catch (NoSuchAlgorithmException nsae) { + nsae.printStackTrace(); + throw new RuntimeException(nsae); + } + } + + @Override + public Integer call() { + toSHAHexString(); + return 0; + } + + protected void toSHAHexString() { + for(int i = start; i < length + start; i++) { + dest[i] = toHexString( + generator.digest((src[i] + salt).getBytes(StandardCharsets.UTF_8)), paddingSize); + } + } + } + + + + public static byte[] getSecurityRandomBytes() { + SecureRandom random = new SecureRandom(); + byte[] randBytes = new byte[20]; + random.nextBytes(randBytes); + return randBytes; + } + + public static byte[] getSHA(String input) throws NoSuchAlgorithmException { + return getSHA(input, 256); + } + + public static byte[] int2Bytes(int value) { + byte[] src = new byte[4]; + src[3] = (byte) ((value>>24) & 0xFF); + src[2] = (byte) ((value>>16) & 0xFF); + src[1] = (byte) ((value>>8) & 0xFF); + src[0] = (byte) (value & 0xFF); + return src; + } + + /*** + * Get random HashMap for test of random string + * @param size HashMap size, int + * @return + */ + public static HashMap getRandomHashSetOfString(int size) { + HashMap data = new HashMap<>(); + for (int i = 0; i < size; i++) { + String name = toHexString(int2Bytes(i)); + data.put(name, Integer.toString(i)); + } + logger.info("IDs are: "); + for (Map.Entry element : data.entrySet()) { + logger.info(element.getKey() + ", " + element.getValue()); + } + return data; + } + + public static HashMap getRandomHashSetOfStringForFiveFixed(int size) { + HashMap data = new HashMap<>(); + Random rand = new Random(); + // put several constant for test + String nameTest = "User_11111111111111111111111111111";//randomBytes; + data.put(nameTest, Integer.toString(0)); + nameTest = "User_111111111111111111111111122222";//randomBytes; + data.put(nameTest, Integer.toString(1)); + nameTest = "User_11111111111111111111111133333";//randomBytes; + data.put(nameTest, Integer.toString(2)); + nameTest = "User_11111111111111111111111144444";//randomBytes; + data.put(nameTest, Integer.toString(3)); + nameTest = "User_11111111111111111111111155555";//randomBytes; + data.put(nameTest, Integer.toString(4)); + for (int i = 5; i < size; i++) { + //String randomBytes = new String(getSecurityRandomBytes()); + String name = toHexString(int2Bytes(i));//randomBytes; + data.put(name, Integer.toString(i)); + } + logger.info("IDs are: "); + for (Map.Entry element : data.entrySet()) { + logger.info(element.getKey() + ", " + element.getValue()); + } + return data; + } + + + /** + * Get SHA hash result of given string input + * + * @param input string input + * @param length bit length, e.g., 128 and 256 + * @return + * @throws NoSuchAlgorithmException + */ + public static byte[] getSHA(String input, int length) throws NoSuchAlgorithmException { + return MessageDigest.getInstance("SHA-" + length).digest(input.getBytes(StandardCharsets.UTF_8)); + } + + public static String toHexString(byte[] hash) { + return toHexString(hash, 32); + } + + public static String toHexString(byte[] hash, int paddingSize) { + // Convert byte array into signum representation + BigInteger number = new BigInteger(1, hash); + + // Convert message digest into hex value + StringBuilder hexString = new StringBuilder(number.toString(16)); + + // Pad with leading zeros + while (hexString.length() < paddingSize) { + hexString.insert(0, '0'); + } + + return hexString.toString(); + } + + public static boolean checkHash(byte[] bytes, String hashstr) { + // transfor bytes to String type + StringBuffer hexValues = new StringBuffer(); + for (int i = 0; i < bytes.length; i++) { + int val = ((int) bytes[i]) & 0xff; + if (val < 16) { + hexValues.append("0"); + } + hexValues.append(Integer.toHexString(val)); + } + String bytestr = hexValues.toString(); + return bytestr.equals(hashstr); + } +} diff --git a/ppml/src/main/proto/FLProto.proto b/ppml/src/main/proto/FLProto.proto new file mode 100644 index 00000000000..7e08cf39da3 --- /dev/null +++ b/ppml/src/main/proto/FLProto.proto @@ -0,0 +1,219 @@ +// +// Copyright 2018 Analytics Zoo Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +syntax = 'proto3'; + + +//option java_multiple_files = true; +option java_package = "com.intel.analytics.zoo.ppml.generated"; +option java_outer_classname = "FLProto"; + +// PSI proto +service PSIService { + // Gives SHA256 Hash salt + rpc getSalt (SaltRequest) returns (SaltReply) {} + rpc uploadSet (UploadSetRequest) returns (UploadSetResponse) {} + rpc downloadIntersection (DownloadIntersectionRequest) returns (DownloadIntersectionResponse) {} +} + +enum SIGNAL { + SUCCESS = 0; + WAIT = 1; + TIMEOUT = 2; + EMPTY_INPUT = 3; + ERROR = 4; +} + +message SaltRequest { + string task_id = 1; + int32 client_num = 2; + string secure_code = 3; +} + +message SaltReply { + string salt_reply = 1; +} + +message UploadSetRequest { + string task_id = 1; + string client_id = 2; + int32 split = 3; + int32 num_split = 4; + int32 split_length = 5; + int32 total_length = 6; + repeated string hashedID = 7; +} + +message UploadSetResponse { + string task_id = 1; + SIGNAL status = 2; +} + +message DownloadIntersectionRequest { + string task_id = 1; + int32 split = 2; +} + +message DownloadIntersectionResponse { + string task_id = 1; + SIGNAL status = 2; + int32 split = 3; + int32 num_split = 4; + int32 split_length = 5; + int32 total_length = 6; + repeated string intersection = 7; +} + + +// Parameter Server Proto +service ParameterServerService { + // NN + rpc UploadTrain(UploadRequest) returns (UploadResponse) {} + rpc DownloadTrain(DownloadRequest) returns (DownloadResponse) {} + rpc UploadEvaluate(EvaluateRequest) returns (EvaluateResponse) {} + // Gradient Boosting Tree + rpc UploadSplitTrain(UploadSplitRequest) returns (UploadResponse) {} + rpc DownloadSplitTrain(DownloadSplitRequest) returns (DownloadSplitResponse) {} + rpc Register(RegisterRequest) returns (RegisterResponse) {} + rpc UploadTreeEval(UploadTreeEvalRequest) returns (UploadResponse) {} + rpc UploadTreeLeaves(UploadTreeLeavesRequest) returns (UploadResponse) {} + rpc PredictTree(PredictTreeRequest) returns (PredictTreeResponse) {} +} +// +message FloatTensor { + repeated int32 shape = 1; + repeated float tensor = 2; +} +// +message Table { + TableMetaData metaData = 1; + map table = 2; +} +// +message TableMetaData { + string name = 1; + int32 version = 2; +} + +message TreeLeaves { + string treeID = 1; + repeated int32 leafIndex = 2; + repeated float leafOutput = 3; +} + +message UploadTreeLeavesRequest { + string clientuuid = 1; + TreeLeaves treeLeaves = 2; +} + +message DataSplit { + string treeID = 1; + string nodeID = 2; + int32 featureID = 3; + float splitValue = 4; + float gain = 5; + int32 setLength = 6; + repeated int32 itemSet = 7; + string clientUid = 8; +} + + +message TreePredict { + string treeID = 1; + repeated bool predicts = 2; +} + +message BoostPredict { + repeated TreePredict predicts = 1; +} + +message BoostEval { + repeated TreePredict evaluates = 1; +} + +message DownloadRequest { + TableMetaData metaData = 1; +} + +message DownloadResponse { + Table data = 1; + string response = 2; + int32 code = 3; +} + +message UploadRequest { + string clientuuid = 1; + Table data = 2; +} + +message UploadResponse { + string response = 1; + int32 code = 2; +} + +message RegisterRequest { + string clientuuid = 1; + string token = 2; +} + +message RegisterResponse { + string response = 1; + int32 code = 2; +} + +message EvaluateRequest { + string clientuuid = 1; + Table data = 2; + bool last = 3; +} + +message EvaluateResponse { + Table data = 1; + string response = 2; + int32 code = 3; +} + +message UploadTreeEvalRequest { + string clientuuid = 1; + int32 version = 2; + repeated BoostEval treeEval = 3; +} + +message UploadSplitRequest { + string clientuuid = 1; + DataSplit split = 2; +} + +message PredictTreeRequest { + string clientuuid = 1; + repeated BoostEval treeEval = 2; + int32 bsVersion = 3; +} + +message PredictTreeResponse { + Table result = 1; +} + +message DownloadSplitRequest { + string clientuuid = 1; + string treeID = 2; + string nodeID = 3; +} + +message DownloadSplitResponse { + DataSplit split = 1; + string response = 2; + int32 code = 3; +} diff --git a/ppml/src/main/resources/psi/psi-conf.yaml b/ppml/src/main/resources/psi/psi-conf.yaml new file mode 100644 index 00000000000..444356b1f98 --- /dev/null +++ b/ppml/src/main/resources/psi/psi-conf.yaml @@ -0,0 +1,8 @@ +servicesList: psi + +# Server property +# serverPort: + +# Client property +# clientTarget: +# taskID: \ No newline at end of file diff --git a/ppml/src/main/scala/com/intel/analytics/zoo/ppml/VFLServer.scala b/ppml/src/main/scala/com/intel/analytics/zoo/ppml/VFLServer.scala new file mode 100644 index 00000000000..1abf00c9f61 --- /dev/null +++ b/ppml/src/main/scala/com/intel/analytics/zoo/ppml/VFLServer.scala @@ -0,0 +1,20 @@ +/* + * Copyright 2021 The Analytic Zoo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.zoo.ppml + +object VFLServer { +} diff --git a/zoo/src/main/java/com/intel/analytics/zoo/grpc/AbstractZooGrpc.java b/zoo/src/main/java/com/intel/analytics/zoo/grpc/AbstractZooGrpc.java new file mode 100644 index 00000000000..f43c65c5b5e --- /dev/null +++ b/zoo/src/main/java/com/intel/analytics/zoo/grpc/AbstractZooGrpc.java @@ -0,0 +1,60 @@ +/* + * Copyright 2021 The Analytic Zoo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.zoo.grpc; + +import com.intel.analytics.zoo.utils.ConfigParser; +import org.apache.commons.cli.*; + +import java.io.IOException; + +public abstract class AbstractZooGrpc { + protected String[] args; + protected Options options; + protected String configPath; + protected CommandLine cmd; + protected String serviceList = ""; + + protected T getConfigFromYaml(Class valueType, String defaultConfigPath) + throws IOException { + options = new Options(); + options.addOption(new Option( + "c", "config", true, "config path")); + CommandLineParser parser = new DefaultParser(); + HelpFormatter formatter = new HelpFormatter(); + cmd = null; + + try { + cmd = parser.parse(options, args); + } catch (ParseException e) { + System.out.println(e.getMessage()); + formatter.printHelp("utility-name", options); + System.exit(1); + } + assert cmd != null; + configPath = cmd.getOptionValue("config", defaultConfigPath); + if (configPath != null) { + // config YAML passed, use config YAML first, command-line could overwrite + assert valueType != null; + return ConfigParser.loadConfigFromPath(configPath, valueType); + } + else { + System.out.println("Config is not provided, using default"); + return null; + } + + } +} diff --git a/zoo/src/main/java/com/intel/analytics/zoo/grpc/ZooGrpcClient.java b/zoo/src/main/java/com/intel/analytics/zoo/grpc/ZooGrpcClient.java new file mode 100644 index 00000000000..f2ce520761d --- /dev/null +++ b/zoo/src/main/java/com/intel/analytics/zoo/grpc/ZooGrpcClient.java @@ -0,0 +1,68 @@ +/* + * Copyright 2021 The Analytic Zoo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.zoo.grpc; + + +import io.grpc.*; +import org.apache.log4j.Logger; + + +import java.io.IOException; +import java.util.UUID; +import java.util.function.Function; + +/** + * All Analytics Zoo gRPC clients are based on ZooGrpcClient + * To implement specific gRPC client, overwrite parseConfig() and loadServices() method + */ +public class ZooGrpcClient extends AbstractZooGrpc{ + protected static final Logger logger = Logger.getLogger(ZooGrpcClient.class.getName()); + protected String target; + protected final String clientUUID; + protected ManagedChannel channel; + public ZooGrpcClient(String[] args) { + clientUUID = UUID.randomUUID().toString(); + this.args = args; + } + protected void parseConfig() throws IOException {} + + public void loadServices() {} + + public ManagedChannel getChannel() { + return channel; + } + public void build() throws IOException { + parseConfig(); + + channel = ManagedChannelBuilder.forTarget(target) + // Channels are secure by default (via SSL/TLS). + .usePlaintext() + .build(); + loadServices(); + } + public O call(Function f, I msg) { + O r = null; + try { + r = f.apply(msg); + } catch (Exception e) { + logger.warn("failed"); + } finally { + return r; + } + + } +} diff --git a/zoo/src/main/java/com/intel/analytics/zoo/grpc/ZooGrpcServer.java b/zoo/src/main/java/com/intel/analytics/zoo/grpc/ZooGrpcServer.java index ef1a5ef59d5..42df6efb350 100644 --- a/zoo/src/main/java/com/intel/analytics/zoo/grpc/ZooGrpcServer.java +++ b/zoo/src/main/java/com/intel/analytics/zoo/grpc/ZooGrpcServer.java @@ -1,15 +1,15 @@ /* - * Copyright 2018 Analytics Zoo Authors. + * Copyright 2021 The Analytic Zoo Authors * - * Licensed under the Apache License, Version 2.0 (the "License"); + * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software + * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ @@ -17,38 +17,93 @@ package com.intel.analytics.zoo.grpc; - import io.grpc.BindableService; import io.grpc.Server; import io.grpc.ServerBuilder; +import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; + +import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder; +import org.apache.log4j.Logger; +import javax.net.ssl.SSLException; +import java.io.File; import java.io.IOException; +import java.util.LinkedList; import java.util.concurrent.TimeUnit; -import java.util.logging.Logger; - /** - * Zoo gRPC server class - * After protobuf generated and service is implemented, service could be passed to ZooGrpcServer - * to start serving request. + * All Analytics Zoo gRPC servers are based on ZooGrpcServer + * To implement specific gRPC server, overwrite parseConfig() method + * This class could also be directly used for start a single service */ -public class ZooGrpcServer { - private static final Logger logger = Logger.getLogger(ZooGrpcServer.class.getName()); - private final int port; - private final Server server; +public class ZooGrpcServer extends AbstractZooGrpc{ + protected static final Logger logger = Logger.getLogger(ZooGrpcServer.class.getName()); + protected int port; + protected Server server; + protected LinkedList serverServices; + // TLS arguments + String certChainFilePath; + String privateKeyFilePath; + String trustCertCollectionFilePath; + + + /** + * One Server could support multiple servives. + * Also support a single service constructor + * @param service + */ public ZooGrpcServer(BindableService service) { - this(8980, "zoo-grpc-conf.yaml", service); + this(null, service); } - public ZooGrpcServer(String configPath, BindableService service) { - this(8980, configPath, service); + public ZooGrpcServer(String[] args, BindableService service) { + serverServices = new LinkedList<>(); + if (service != null) { + serverServices.add(service); + } + this.args = args; + } + public ZooGrpcServer(String[] args) { + this(args, null); + } + + public void parseConfig() throws IOException {} + + /** Entrypoint of ZooGrpcServer */ - public ZooGrpcServer(int port, String configPath, BindableService service) { - this.port = port; - server = ServerBuilder.forPort(port) - .addService(service) - .build(); + public void build() throws IOException { + parseConfig(); + ServerBuilder builder = ServerBuilder.forPort(port); + for (BindableService bindableService : serverServices) { + builder.addService(bindableService); + } + server = builder.maxInboundMessageSize(Integer.MAX_VALUE).build(); + } + + void buildWithTls() throws IOException { + parseConfig(); + NettyServerBuilder serverBuilder = NettyServerBuilder.forPort(port); + for (BindableService bindableService : serverServices) { + serverBuilder.addService(bindableService); + } + if (certChainFilePath != null && privateKeyFilePath != null) { + serverBuilder.sslContext(getSslContext()); + } + server = serverBuilder.build(); } + SslContext getSslContext() throws SSLException { + SslContextBuilder sslClientContextBuilder = SslContextBuilder.forServer(new File(certChainFilePath), + new File(privateKeyFilePath)); + if (trustCertCollectionFilePath != null) { + sslClientContextBuilder.trustManager(new File(trustCertCollectionFilePath)); + sslClientContextBuilder.clientAuth(ClientAuth.REQUIRE); + } + return GrpcSslContexts.configure(sslClientContextBuilder).build(); + } + /** Start serving requests. */ public void start() throws IOException { diff --git a/zoo/src/main/java/com/intel/analytics/zoo/utils/ConfigParser.java b/zoo/src/main/java/com/intel/analytics/zoo/utils/ConfigParser.java new file mode 100644 index 00000000000..146b192261e --- /dev/null +++ b/zoo/src/main/java/com/intel/analytics/zoo/utils/ConfigParser.java @@ -0,0 +1,50 @@ +/* + * Copyright 2021 The Analytic Zoo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.zoo.utils; + + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.MapperFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; + +import java.io.IOException; + +/** + * ConfigParser has static method to read config formatted in JavaBean class from YAML file + */ +public class ConfigParser { + static ObjectMapper objectMapper; + static { + objectMapper = new ObjectMapper(new YAMLFactory()); + objectMapper.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true); + objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + objectMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); + } + + public static T loadConfigFromPath(String configPath, Class valueType) + throws IOException { + return objectMapper.readValue(new java.io.File(configPath), valueType); + } + public static T loadConfigFromString(String configString, Class valueType) + throws JsonProcessingException { + + return objectMapper.readValue(configString, valueType); + } +} \ No newline at end of file diff --git a/zoo/src/test/java/com/intel/analytics/zoo/utils/ConfigParserTest.java b/zoo/src/test/java/com/intel/analytics/zoo/utils/ConfigParserTest.java new file mode 100644 index 00000000000..d9717bcb389 --- /dev/null +++ b/zoo/src/test/java/com/intel/analytics/zoo/utils/ConfigParserTest.java @@ -0,0 +1,68 @@ +/* + * Copyright 2021 The Analytic Zoo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.zoo.utils; + +import com.fasterxml.jackson.core.JsonProcessingException; +import org.junit.Assert; +import org.junit.Test; + +public class ConfigParserTest { + @Test + public void testConfigParserFromString() throws JsonProcessingException { + String testString = String.join("\n", + "stringProp: abc", + "intProp: 123", + "boolProp: true"); + TestHelper testHelper = ConfigParser.loadConfigFromString(testString, TestHelper.class); + Assert.assertEquals(testHelper.intProp, 123); + Assert.assertEquals(testHelper.boolProp, true); + Assert.assertEquals(testHelper.stringProp, "abc"); + } + @Test + public void testConfigParserFromStringWithEmptyBool() throws JsonProcessingException { + String testString = String.join("\n", + "stringProp: abc", + "intProp: 123"); + TestHelper testHelper = ConfigParser.loadConfigFromString(testString, TestHelper.class); + Assert.assertEquals(testHelper.intProp, 123); + Assert.assertEquals(testHelper.boolProp, false); + Assert.assertEquals(testHelper.stringProp, "abc"); + } + @Test + public void testConfigParserFromStringWithEmptyString() throws JsonProcessingException { + String testString = String.join("\n", + "boolProp: true", + "intProp: 123"); + TestHelper testHelper = ConfigParser.loadConfigFromString(testString, TestHelper.class); + Assert.assertEquals(testHelper.intProp, 123); + Assert.assertEquals(testHelper.boolProp, true); + Assert.assertEquals(testHelper.stringProp, null); + } + @Test + public void testConfigParserFromStringWithExtra() throws JsonProcessingException { + String testString = String.join("\n", + "stringProp: abc", + "intProp: 123", + "invalidProp: 123"); + TestHelper testHelper = ConfigParser.loadConfigFromString(testString, TestHelper.class); + Assert.assertEquals(testHelper.intProp, 123); + Assert.assertEquals(testHelper.boolProp, false); + Assert.assertEquals(testHelper.stringProp, "abc"); + } + +} + diff --git a/zoo/src/test/java/com/intel/analytics/zoo/utils/TestHelper.java b/zoo/src/test/java/com/intel/analytics/zoo/utils/TestHelper.java new file mode 100644 index 00000000000..4d25cb8fde8 --- /dev/null +++ b/zoo/src/test/java/com/intel/analytics/zoo/utils/TestHelper.java @@ -0,0 +1,35 @@ +/* + * Copyright 2021 The Analytic Zoo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.zoo.utils; + +public class TestHelper { + String stringProp; + int intProp; + boolean boolProp; + + public void setBoolProp(boolean boolProp) { + this.boolProp = boolProp; + } + + public void setIntProp(int intProp) { + this.intProp = intProp; + } + + public void setStringProp(String stringProp) { + this.stringProp = stringProp; + } +} \ No newline at end of file