diff --git a/internal/service/network/create.go b/internal/service/network/create.go
index d1bfc3cb..dea8252a 100644
--- a/internal/service/network/create.go
+++ b/internal/service/network/create.go
@@ -88,6 +88,16 @@ func (s *service) Create(ctx context.Context, request types.NetworkCreateRequest
 		return types.NetworkCreateResponse{}, err
 	}
 
+	// Serialize operations on the same network name using a per-network mutex.
+	// Operations on different network names can proceed concurrently.
+	netMu := s.ensureLock(request.Name)
+	// Release our reference after the mutex is unlocked (deferred calls run in LIFO order)
+	// so the map entry is dropped once no operation on this name is in flight.
+	defer s.clearLock(request.Name)
+
+	netMu.Lock()
+	defer netMu.Unlock()
+
 	// Create network
 	net, err := s.netClient.CreateNetwork(options)
 	if err != nil && strings.Contains(err.Error(), "unsupported cni driver") {
diff --git a/internal/service/network/network.go b/internal/service/network/network.go
index 8f29f54a..fd7cc866 100644
--- a/internal/service/network/network.go
+++ b/internal/service/network/network.go
@@ -8,6 +8,7 @@ import (
 	"errors"
 	"fmt"
 	"regexp"
+	"sync"
 
 	"github.com/containerd/containerd"
 	"github.com/containerd/nerdctl/pkg/netutil"
@@ -25,16 +26,19 @@ var (
 )
 
 type service struct {
-	client    backend.ContainerdClient
-	netClient backend.NerdctlNetworkSvc
-	logger    flog.Logger
+	client         backend.ContainerdClient
+	netClient      backend.NerdctlNetworkSvc
+	logger         flog.Logger
+	mu             sync.Mutex              // guards networkMutexes and the refs field of its entries
+	networkMutexes map[string]*networkLock // reference-counted locks keyed by network name
 }
 
 func NewService(client backend.ContainerdClient, netClient backend.NerdctlNetworkSvc, logger flog.Logger) network.Service {
 	return &service{
-		client:    client,
-		netClient: netClient,
-		logger:    logger,
+		client:         client,
+		netClient:      netClient,
+		logger:         logger,
+		networkMutexes: make(map[string]*networkLock),
 	}
 }
 
@@ -97,3 +101,46 @@ func (s *service) getContainer(ctx context.Context, containerId string) (contain
 
 	return searchResult[0], nil
 }
+
+// networkLock is a reference-counted mutex that serializes operations on one network name.
+type networkLock struct {
+	mu   sync.Mutex
+	refs int // ensureLock calls not yet balanced by clearLock; guarded by service.mu
+}
+
+// ensureLock returns the mutex for networkName, creating it on first use, and
+// registers the caller as a user of that mutex. Every call must be paired with
+// exactly one clearLock call for the same name, made after the returned mutex
+// has been unlocked.
+func (s *service) ensureLock(networkName string) *sync.Mutex {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	netLock, exists := s.networkMutexes[networkName]
+	if !exists {
+		netLock = &networkLock{}
+		s.networkMutexes[networkName] = netLock
+	}
+	netLock.refs++
+
+	return &netLock.mu
+}
+
+// clearLock drops the caller's registration made by ensureLock and deletes the
+// map entry once no operation is registered on it. Deleting only at zero means
+// goroutines still waiting on the mutex and later ensureLock callers all share
+// the same mutex, so operations on one name stay mutually exclusive.
+func (s *service) clearLock(networkName string) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	netLock, exists := s.networkMutexes[networkName]
+	if !exists {
+		return
+	}
+
+	netLock.refs--
+	if netLock.refs <= 0 {
+		delete(s.networkMutexes, networkName)
+	}
+}
diff --git a/internal/service/network/remove.go b/internal/service/network/remove.go
index 2d557b7b..78ad4a80 100644
--- a/internal/service/network/remove.go
+++ b/internal/service/network/remove.go
@@ -32,12 +32,28 @@ func (s *service) Remove(ctx context.Context, networkId string) error {
 		return errdefs.NewForbidden(fmt.Errorf("%s is a pre-defined network and cannot be removed", networkId))
 	}
 
+	// Serialize operations on the same network name using a per-network mutex.
+	// RemoveNetwork and CreateNetwork operations on the same network name are mutually exclusive.
+	// Operations on different network names can proceed concurrently.
+	netMu := s.ensureLock(net.Name)
+	// Release our reference after the mutex is unlocked (deferred calls run in LIFO order),
+	// whether or not the removal succeeds, so idle entries never accumulate.
+	defer s.clearLock(net.Name)
+
+	netMu.Lock()
+	defer netMu.Unlock()
+
 	// Perform additional workflow based on the assigned network labels
 	if err := s.handleNetworkLabels(net); err != nil {
 		return fmt.Errorf("failed to handle nerdctl label: %w", err)
 	}
 
-	return s.netClient.RemoveNetwork(net)
+	err = s.netClient.RemoveNetwork(net)
+	if err != nil {
+		return fmt.Errorf("failed to remove network: %w", err)
+	}
+
+	return nil
 }
 
 func (s *service) handleNetworkLabels(net *netutil.NetworkConfig) error {