diff --git a/src/lava/magma/runtime/message_infrastructure/CMakeLists.txt b/src/lava/magma/runtime/message_infrastructure/CMakeLists.txt index ff1bbbc8f..a1da3b2cf 100644 --- a/src/lava/magma/runtime/message_infrastructure/CMakeLists.txt +++ b/src/lava/magma/runtime/message_infrastructure/CMakeLists.txt @@ -15,6 +15,7 @@ set (MESSAGE_INFRASTRUCTURE_SRCS "message_infrastructure/csrc/multiprocessing.cc" "message_infrastructure/csrc/posix_actor.cc" "message_infrastructure/csrc/shm.cc" + "message_infrastructure/csrc/socket.cc" "message_infrastructure/csrc/shmem_channel.cc" "message_infrastructure/csrc/shmem_port.cc" ) diff --git a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/channel_factory.h b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/channel_factory.h index 732b2e1bc..b9b78ca68 100644 --- a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/channel_factory.h +++ b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/channel_factory.h @@ -28,8 +28,10 @@ class ChannelFactory { break; case DDSCHANNEL: break; + case SOCKETCHANNEL: + default: - return GetShmemChannel(size, nbytes, src_name, dst_name); + return GetShmemChannel(channel_type, size, nbytes, src_name, dst_name); } return NULL; } diff --git a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/communicator.h b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/communicator.h new file mode 100644 index 000000000..de531426c --- /dev/null +++ b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/communicator.h @@ -0,0 +1,26 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef COMMUNICATOR_H_ +#define COMMUNICATOR_H_ + +#include + +namespace message_infrastructure { + +using HandleFn = std::function; + +class SharedCommunicator { + public: + SharedCommunicator() {} + virtual void Start() = 0; + virtual bool Load(HandleFn consume_fn) = 0; + virtual void Store(HandleFn store_fn) = 0; +}; + +using SharedCommunicatorPtr = std::shared_ptr; + +} // namespace message_infrastructure + +#endif // COMMUNICATOR_H_ \ No newline at end of file diff --git a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/message_infrastructure_logging.h b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/message_infrastructure_logging.h index 48119841c..a9fe2aa08 100644 --- a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/message_infrastructure_logging.h +++ b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/message_infrastructure_logging.h @@ -11,6 +11,7 @@ #define LOG_LAYER (0) #define DEBUG_MODE (0) #define LOG_SMMP (0) // log for shmemport +#define LOG_SSKP (0) // log for socketport #define LAVA_LOG(_cond, _fmt, ...) { \ if ((_cond)) { \ diff --git a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/message_infrastructure_py_wrapper.cc b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/message_infrastructure_py_wrapper.cc index 746a806ad..3df0ffbe5 100644 --- a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/message_infrastructure_py_wrapper.cc +++ b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/message_infrastructure_py_wrapper.cc @@ -74,6 +74,7 @@ PYBIND11_MODULE(MessageInfrastructurePywrapper, m) { .value("SHMEMCHANNEL", SHMEMCHANNEL) .value("RPCCHANNEL", RPCCHANNEL) .value("DDSCHANNEL", DDSCHANNEL) + .value("SOCKETCHANNEL", SOCKETCHANNEL) .export_values(); py::class_> (m, "AbstractTransferPort") .def(py::init<>()); diff --git a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shm.h b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shm.h index 176806344..151824513 100644 --- a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shm.h +++ b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shm.h @@ -22,6 +22,7 @@ #include #include "message_infrastructure_logging.h" +#include "communicator.h" namespace message_infrastructure { @@ -30,7 +31,7 @@ namespace message_infrastructure { using HandleFn = std::function; -class SharedMemory { +class SharedMemory : public SharedCommunicator { public: SharedMemory() {} SharedMemory(const size_t &mem_size, const int &shmfd, const int &key); diff --git a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shmem_channel.cc b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shmem_channel.cc index b9ad8586b..6a386257e 100644 --- a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shmem_channel.cc +++ b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shmem_channel.cc @@ -8,16 +8,29 @@ namespace message_infrastructure { -ShmemChannel::ShmemChannel(const std::string &src_name, +ShmemChannel::ShmemChannel(const ChannelType &channel_type, + const std::string &src_name, const std::string &dst_name, const size_t &size, const size_t &nbytes) { unsigned long shmem_size = nbytes + sizeof(MetaData); - - shm_ = GetSharedMemManager().AllocChannelSharedMemory(shmem_size); - - send_port_ = std::make_shared(src_name, shm_, size, shmem_size); - recv_port_ = std::make_shared(dst_name, shm_, size, shmem_size); + size_t items_size = size; + switch (channel_type) { + case RPCCHANNEL: + break; + case DDSCHANNEL: + break; + case SOCKETCHANNEL: + items_size = 0; + shm_ = GetSharedSktManager().AllocChannelSharedSocket(shmem_size); + break; + default: + shm_ = GetSharedMemManager().AllocChannelSharedMemory(shmem_size); + } + // shm_ = GetSharedMemManager().AllocChannelSharedMemory(shmem_size); + + send_port_ = std::make_shared(src_name, shm_, items_size, shmem_size); + recv_port_ = std::make_shared(dst_name, shm_, items_size, shmem_size); } AbstractSendPortPtr ShmemChannel::GetSendPort() { @@ -28,11 +41,13 @@ AbstractRecvPortPtr ShmemChannel::GetRecvPort() { return recv_port_; } -std::shared_ptr GetShmemChannel(const size_t &size, +std::shared_ptr GetShmemChannel(const ChannelType &channel_type, + const size_t &size, const size_t &nbytes, const std::string &src_name, const std::string &dst_name) { - return (std::make_shared(src_name, + return (std::make_shared(channel_type, + src_name, dst_name, size, nbytes)); diff --git a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shmem_channel.h b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shmem_channel.h index e2e7be153..42afcc22a 100644 --- a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shmem_channel.h +++ b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shmem_channel.h @@ -11,6 +11,7 @@ #include "abstract_channel.h" #include "abstract_port.h" #include "shm.h" +#include "socket.h" #include "shmem_port.h" namespace message_infrastructure { @@ -18,19 +19,21 @@ namespace message_infrastructure { class ShmemChannel : public AbstractChannel { public: ShmemChannel() {} - ShmemChannel(const std::string &src_name, + ShmemChannel(const ChannelType &channel_type, + const std::string &src_name, const std::string &dst_name, const size_t &size, const size_t &nbytes); AbstractSendPortPtr GetSendPort(); AbstractRecvPortPtr GetRecvPort(); private: - SharedMemoryPtr shm_ = nullptr; - ShmemSendPortPtr send_port_ = nullptr; - ShmemRecvPortPtr recv_port_ = nullptr; + SharedCommunicatorPtr shm_ = NULL; + ShmemSendPortPtr send_port_ = NULL; + ShmemRecvPortPtr recv_port_ = NULL; }; -std::shared_ptr GetShmemChannel(const size_t &size, +std::shared_ptr GetShmemChannel(const ChannelType &channel_type, + const size_t &size, const size_t &nbytes, const std::string &src_name, const std::string &dst_name); diff --git a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shmem_port.cc b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shmem_port.cc index 87e113e5f..fdf68e0d6 100644 --- a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shmem_port.cc +++ b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shmem_port.cc @@ -118,7 +118,7 @@ ShmemRecvQueue::~ShmemRecvQueue() { } ShmemSendPort::ShmemSendPort(const std::string &name, - SharedMemoryPtr shm, + SharedCommunicatorPtr shm, const size_t &size, const size_t &nbytes) : AbstractSendPort(name, size, nbytes), shm_(shm), done_(false) @@ -146,15 +146,17 @@ void ShmemSendPort::Join() { } ShmemRecvPort::ShmemRecvPort(const std::string &name, - SharedMemoryPtr shm, + SharedCommunicatorPtr shm, const size_t &size, const size_t &nbytes) : AbstractRecvPort(name, size, nbytes), shm_(shm), done_(false) { + if (size_ != 0) queue_ = std::make_shared(name_, size_, nbytes_); } void ShmemRecvPort::Start() { + if (size_ != 0) recv_queue_thread_ = std::make_shared(&message_infrastructure::ShmemRecvPort::QueueRecv, this); } @@ -173,12 +175,41 @@ void ShmemRecvPort::QueueRecv() { } } +char * ShmemRecvPort::NoQueueRecv(){ + if(!done_.load()) { + bool ret = false; + // if (this->queue_->AvailableCount() > 0) { + void *ptr = malloc(nbytes_); + ret = shm_->Load([this, ptr](void* data){ + //this->queue_->Push(data); + std::memcpy(ptr, data, nbytes_); + }); + // } + if (!ret) { + // sleep + // helper::Sleep(); + free(ptr); + return NULL; + } + return (char *)ptr; + } + return NULL; +} + bool ShmemRecvPort::Probe() { return queue_->Probe(); } MetaDataPtr ShmemRecvPort::Recv() { - char *cptr = (char *)queue_->Pop(true); + char *cptr; + if (size_ != 0){ + cptr = (char *)queue_->Pop(true); + }else{ + cptr = NoQueueRecv(); + } + if (cptr == NULL){ + + } MetaDataPtr metadata_res = std::make_shared(); std::memcpy(metadata_res.get(), cptr, sizeof(MetaData)); metadata_res->mdata = (void*)(cptr + sizeof(MetaData)); @@ -188,8 +219,10 @@ MetaDataPtr ShmemRecvPort::Recv() { void ShmemRecvPort::Join() { if (!done_) { done_ = true; + if (size_ != 0){ recv_queue_thread_->join(); queue_->Stop(); + } } } diff --git a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shmem_port.h b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shmem_port.h index f6176118d..35debd53e 100644 --- a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shmem_port.h +++ b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/shmem_port.h @@ -21,7 +21,7 @@ using ThreadPtr = std::shared_ptr; class ShmemSendPort final : public AbstractSendPort { public: ShmemSendPort(const std::string &name, - SharedMemoryPtr shm, + SharedCommunicatorPtr shm, const size_t &size, const size_t &nbytes); void Start(); @@ -30,7 +30,7 @@ class ShmemSendPort final : public AbstractSendPort { bool Probe(); private: - SharedMemoryPtr shm_ = nullptr; + SharedCommunicatorPtr shm_ = nullptr; int idx_ = 0; std::atomic_bool done_; ThreadPtr ack_callback_thread_ = nullptr; @@ -68,7 +68,7 @@ using ShmemRecvQueuePtr = std::shared_ptr; class ShmemRecvPort final : public AbstractRecvPort { public: ShmemRecvPort(const std::string &name, - SharedMemoryPtr shm, + SharedCommunicatorPtr shm, const size_t &size, const size_t &nbytes); void Start(); @@ -77,9 +77,10 @@ class ShmemRecvPort final : public AbstractRecvPort { void Join(); MetaDataPtr Peek(); void QueueRecv(); + char *NoQueueRecv(); private: - SharedMemoryPtr shm_ = nullptr; + SharedCommunicatorPtr shm_ = nullptr; int idx_ = 0; std::atomic_bool done_; ShmemRecvQueuePtr queue_ = nullptr; diff --git a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/socket.cc b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/socket.cc new file mode 100644 index 000000000..c6dde885d --- /dev/null +++ b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/socket.cc @@ -0,0 +1,190 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#include "socket.h" + +namespace message_infrastructure { + +SharedSocket::SharedSocket(const size_t &mem_size, int socket[2], const int &key) { +// shmfd_ = shmfd; + socket_[0] = socket[0]; + socket_[1] = socket[1]; + size_ = mem_size; + req_name_ += std::to_string(key); + ack_name_ += std::to_string(key); +} + +SharedSocket::SharedSocket(const size_t &mem_size, int socket[2]) { + socket_[0] = socket[0]; + socket_[1] = socket[1]; + size_ = mem_size; +} + +void SharedSocket::InitSemaphore() { + req_ = sem_open(req_name_.c_str(), O_CREAT, 0644, 0); + ack_ = sem_open(ack_name_.c_str(), O_CREAT, 0644, 0); +} + +void SharedSocket::Start() { + // RecvPort will post init sem. +} + +void SharedSocket::Store(HandleFn store_fn) { + + long long buffer[(size_+1)/8]; + store_fn((void*)buffer); + char temp; + // sem_wait(ack_); + LAVA_LOG_ERR("prewrite data.\n"); + size_t length = write(socket_[0], (char *)buffer, size_); + LAVA_LOG_ERR("postwrite data.\n"); + // sem_post(req_); + if (length == -1){ + LAVA_LOG_ERR("Write socket failed.\n"); + exit(-1); + } else if (length != size_){ + LAVA_LOG_ERR("Write socket error, expected size: %zd, got size: %zd", size_, length); + exit(-1); + } + LAVA_LOG_ERR("preread ack.\n"); + length = read(socket_[0], &temp, 1); + LAVA_LOG_ERR("postread ack.\n"); + if (length != 1 || temp != 'a'){ + + } +// LAVA_LOG_ERR("Write socket size: %zd.\n", length); +// for(int i=0;i<(size_+1)/8;i++){ +// LAVA_LOG_ERR("Write Socket Buffer: %lld\n", buffer[i]); +// } +} + +bool SharedSocket::Load(HandleFn consume_fn) { + long long buffer[(size_+1)/8]; + bool ret = false; + int val; + char temp = 'a'; + size_t length = 0; + // if (!sem_trywait(req_)) + // { + LAVA_LOG_ERR("preread data.\n"); + length = read(socket_[1], (char *)buffer, size_); + LAVA_LOG_ERR("postread data.\n"); + // consume_fn(MemMap()); + ret = true; + // } + // sem_getvalue(ack_, &val); + // if (val == 0) { + // sem_post(ack_); + // } + if (!ret){ + return ret; + } + if (length < 0) { + ret = false; + LAVA_LOG_ERR("Read socket failed."); + exit(-1); + } else if (size_ != length) { + ret = false; + LAVA_LOG_ERR("Read socket error, expected size: %zd, got size: %zd", size_, length); + exit(-1); + } + consume_fn((void*)buffer); +// LAVA_LOG_ERR("Read socket size: %zd.\n", length); +// for(int i=0;i<(size_+1)/8;i++){ +// LAVA_LOG_ERR("Read Socket Buffer: %lld\n", buffer[i]); +// } + LAVA_LOG_ERR("prewrite ack.\n"); + length = write(socket_[1], &temp, 1); + LAVA_LOG_ERR("postwrite ack.\n"); + if (length != 1){ + + } + return ret; +} + +void SharedSocket::Close() { + sem_close(req_); + sem_close(ack_); +} + +void* SharedSocket::MemMap() { +// return (data_ = mmap(NULL, size_, PROT_READ | PROT_WRITE, MAP_SHARED, shmfd_, 0)); + return NULL; +} + + +// int SharedSocket::GetDataElem(int offset) { +// return static_cast (*(((char*)data_) + offset)); +// } + +SharedSocket::~SharedSocket() { + Close(); + sem_unlink(req_name_.c_str()); + sem_unlink(ack_name_.c_str()); +} + +// RwSharedSocket::RwSharedSocket(const size_t &mem_size, const int &shmfd, const int &key) +// : size_(mem_size), shmfd_(shmfd) +// { +// sem_name_ += std::to_string(key); +// } + +// void RwSharedSocket::InitSemaphore() { +// sem_ = sem_open(sem_name_.c_str(), O_CREAT, 0644, 0); +// } + +// void RwSharedSocket::Start() { +// sem_post(sem_); +// } + +// void RwSharedSocket::Handle(HandleFn handle_fn) { +// sem_wait(sem_); +// handle_fn(GetData()); +// sem_post(sem_); +// } + +// void RwSharedSocket::Close() { +// sem_close(sem_); +// } + +// void* RwSharedSocket::GetData() { +// return (data_ = mmap(NULL, size_, PROT_READ | PROT_WRITE, MAP_SHARED, shmfd_, 0)); +// } + +// RwSharedSocket::~RwSharedSocket() { +// Close(); +// sem_unlink(sem_name_.c_str()); +// } + +void SharedSktManager::DeleteSharedSocket(int &socket) { + if (sockets_.find(socket) != sockets_.end()) { + close(socket); + sockets_.erase(socket); + } else { + LAVA_LOG_WARN(LOG_SSKP,"There is no socket whose fd is %d.\n", socket); + } + // Release specific shared memory +// if (shm_strs_.find(shm_str) != shm_strs_.end()) { +// shm_unlink(shm_str.c_str()); +// shm_strs_.erase(shm_str); +// } else { +// LAVA_LOG_WARN(LOG_SMMP,"There is no shmem whose name is %s.\n", shm_str.c_str()); +// } +} + +SharedSktManager::~SharedSktManager() { + int result = 0; + for (auto it = sockets_.begin(); it != sockets_.end(); it++) { + close(*it); + } + sockets_.clear(); +} + +SharedSktManager SharedSktManager::ssm_; + +SharedSktManager& GetSharedSktManager() { + SharedSktManager &ssm = SharedSktManager::ssm_; + return ssm; +} +} // namespace message_infrastructure diff --git a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/socket.h b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/socket.h new file mode 100644 index 000000000..5cffae540 --- /dev/null +++ b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/socket.h @@ -0,0 +1,137 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// See: https://spdx.org/licenses/ + +#ifndef SOCKET_H_ +#define SOCKET_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "message_infrastructure_logging.h" +#include "communicator.h" + +namespace message_infrastructure { + +// #define SHM_FLAG O_RDWR | O_CREAT +// #define SHM_MODE S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP | S_IROTH | S_IWOTH + +// using HandleFn = std::function; + +class SharedSocket : public SharedCommunicator{ + public: + SharedSocket() {} + SharedSocket(const size_t &mem_size, int socket[2], const int &key); + SharedSocket(const size_t &mem_size, int socket[2]); + ~SharedSocket(); + void Start(); + bool Load(HandleFn consume_fn); + void Store(HandleFn store_fn); + void Close(); + void InitSemaphore(); +// int GetDataElem(int offset); + + private: + size_t size_; +// int shmfd_; + int socket_[2]; + std::string req_name_ = "req"; + std::string ack_name_ = "ack"; + sem_t *req_; + sem_t *ack_; + void *data_; + + void* MemMap(); +}; + +// class RwSharedSocket { +// public: +// RwSharedSocket(const size_t &mem_size, const int &shmfd, const int &key); +// ~RwSharedSocket(); +// void InitSemaphore(); +// void Start(); +// void Handle(HandleFn handle_fn); +// void Close(); + +// private: +// size_t size_; +// int shmfd_; +// std::string sem_name_ = "sem"; +// sem_t *sem_; +// void *data_; + +// void *GetData(); +// }; + +using SharedSocketPtr = std::shared_ptr; +// using RwSharedSocketPtr = std::shared_ptr; + +class SharedSktManager { + public: + ~SharedSktManager(); + + template + std::shared_ptr AllocChannelSharedSocket(const size_t &mem_size) { + int socket[2]; + int random = std::rand(); + int err = socketpair(AF_LOCAL, SOCK_SEQPACKET, 0, socket); + if (err == -1){ + LAVA_LOG_ERR("Create shared socket object failed.\n"); + exit(-1); + } + sockets_.insert(socket[0]); + sockets_.insert(socket[1]); + std::shared_ptr sharedSocket = std::make_shared(mem_size, socket, random); + sharedSocket->InitSemaphore(); + return sharedSocket; + + // int random = std::rand(); + // std::string str = shm_str_ + std::to_string(random); + // int shmfd = shm_open(str.c_str(), SHM_FLAG, SHM_MODE); + // if (shmfd == -1) { + // LAVA_LOG_ERR("Create shared memory object failed.\n"); + // exit(-1); + // } + // int err = ftruncate(shmfd, mem_size); + // if (err == -1) { + // LAVA_LOG_ERR("Resize shared memory segment failed.\n"); + // exit(-1); + // } + // shm_strs_.insert(str); + // std::shared_ptr shm = std::make_shared(mem_size, shmfd, random); + // shm->InitSemaphore(); + // return shm; + } + + void DeleteSharedSocket(int &socket); + friend SharedSktManager &GetSharedSktManager(); + + private: + SharedSktManager() { + std::srand(std::time(nullptr)); + } + std::set sockets_; + static SharedSktManager ssm_; +// std::string shm_str_ = "shm"; +}; + +SharedSktManager& GetSharedSktManager(); + +using SharedSktManagerPtr = std::shared_ptr; + +} // namespace message_infrastructure + +#endif // SOCKET_H_ diff --git a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/utils.h b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/utils.h index 11931b70a..6bd6dcfeb 100644 --- a/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/utils.h +++ b/src/lava/magma/runtime/message_infrastructure/message_infrastructure/csrc/utils.h @@ -23,7 +23,8 @@ enum ProcessType { enum ChannelType { SHMEMCHANNEL = 0, RPCCHANNEL = 1, - DDSCHANNEL = 2 + DDSCHANNEL = 2, + SOCKETCHANNEL = 3 }; struct MetaData { diff --git a/src/lava/magma/runtime/message_infrastructure/setenv.sh b/src/lava/magma/runtime/message_infrastructure/setenv.sh index 2a77fe3ee..3264150eb 100644 --- a/src/lava/magma/runtime/message_infrastructure/setenv.sh +++ b/src/lava/magma/runtime/message_infrastructure/setenv.sh @@ -1,3 +1,3 @@ -SCRIPTPATH=$(cd `dirname -- $0` && pwd) +SCRIPTPATH=$(cd `dirname -- $BASH_SOURCE` && pwd) export PYTHONPATH="${SCRIPTPATH}/build:${SCRIPTPATH}:$PYTHONPATH" export LD_LIBRARY_PATH="${SCRIPTPATH}/build:${SCRIPTPATH}:$LD_LIBRARY_PATH" diff --git a/src/lava/magma/runtime/message_infrastructure/test/test_channel.py b/src/lava/magma/runtime/message_infrastructure/test/test_channel.py index 8cb67a2c4..ed9d145ff 100644 --- a/src/lava/magma/runtime/message_infrastructure/test/test_channel.py +++ b/src/lava/magma/runtime/message_infrastructure/test/test_channel.py @@ -6,6 +6,7 @@ import unittest from functools import partial import time +from multiprocessing import Process from message_infrastructure.multiprocessing import MultiProcessing @@ -115,6 +116,95 @@ def test_single_process_shmemchannel(self): send_port.join() recv_port.join() +class TestSocketChannel(unittest.TestCase): + + def test_socketchannel(self): + mp = MultiProcessing() + mp.start() + size = 5 + predata = prepare_data() + nbytes = np.prod(predata.shape) * predata.dtype.itemsize + name = 'test_socket_channel' + + socket_channel = Channel( + ChannelBackend.SOCKETCHANNEL, + size, + nbytes, + name, + name) + + send_port = socket_channel.src_port + recv_port = socket_channel.dst_port + + recv_port_fn = partial(recv_proc, port=recv_port) + send_port_fn = partial(send_proc, port=send_port) + + builder1 = Builder() + builder2 = Builder() + mp.build_actor(recv_port_fn, builder1) + mp.build_actor(send_port_fn, builder2) + + time.sleep(2) + mp.stop(True) + + # def test_single_process_socketchannel(self): + # size = 5 + # predata = prepare_data() + # nbytes = np.prod(predata.shape) * predata.dtype.itemsize + # name = 'test_single_process_socket_channel' + + # socket_channel = Channel( + # ChannelBackend.SOCKETCHANNEL, + # size, + # nbytes, + # name, + # name) + + # send_port = socket_channel.src_port + # recv_port = socket_channel.dst_port + + # send_port.start() + # recv_port.start() + + # send_port.send(predata) + # resdata = recv_port.recv() + + # if not np.array_equal(resdata, predata): + # raise AssertionError() + + # send_port.join() + # recv_port.join() + + def test_multi_process_socket_channel(self): + size = 5 + predata = prepare_data() + nbytes = np.prod(predata.shape) * predata.dtype.itemsize + name = 'test_single_process_socket_channel' + + socket_channel = Channel( + ChannelBackend.SOCKETCHANNEL, + size, + nbytes, + name, + name) + + send_port = socket_channel.src_port + recv_port = socket_channel.dst_port + + send_port.start() + recv_port.start() + t = Process(target=send_port.send, args=(predata,)) + t.start() + # send_port.send(predata) + resdata = recv_port.recv() + + if not np.array_equal(resdata, predata): + raise AssertionError() + + send_port.join() + recv_port.join() + t.join() + if __name__ == "__main__": unittest.main()