Skip to content

Commit

Permalink
VPN snooze mode (#924)
Browse files Browse the repository at this point in the history
Required:

Task/Issue URL: https://app.asana.com/0/72649045549333/1207974416599035/f
iOS PR: duckduckgo/iOS#3184
macOS PR: duckduckgo/macos-browser#3085
What kind of version bump will this require?: Major

Description:

This PR adds VPN snooze mode support.
  • Loading branch information
samsymons authored Aug 12, 2024
1 parent d679798 commit c3ae186
Show file tree
Hide file tree
Showing 10 changed files with 313 additions and 17 deletions.
28 changes: 26 additions & 2 deletions Sources/NetworkProtection/ExtensionMessage/ExtensionMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ public enum ExtensionMessage: RawRepresentable {
case simulateTunnelMemoryOveruse
case simulateConnectionInterruption
case getDataVolume
case startSnooze
case cancelSnooze
}

// This is actually an improved way to send messages.
Expand All @@ -69,6 +71,8 @@ public enum ExtensionMessage: RawRepresentable {
case simulateTunnelMemoryOveruse
case simulateConnectionInterruption
case getDataVolume
case startSnooze(TimeInterval)
case cancelSnooze

public init?(rawValue data: Data) {
let name = data.first.flatMap(Name.init(rawValue:))
Expand Down Expand Up @@ -138,6 +142,18 @@ public enum ExtensionMessage: RawRepresentable {
case .getDataVolume:
self = .getDataVolume

case .startSnooze:
guard data.count == MemoryLayout<UInt>.size + 1 else { return nil }
let uintValue = data.withUnsafeBytes {
$0.loadUnaligned(fromByteOffset: 1, as: UInt.self)
}
let snoozeDuration = TimeInterval(uintValue.littleEndian)

self = .startSnooze(snoozeDuration)

case .cancelSnooze:
self = .cancelSnooze

case .none:
assertionFailure("Invalid data")
return nil
Expand Down Expand Up @@ -165,6 +181,8 @@ public enum ExtensionMessage: RawRepresentable {
case .simulateTunnelMemoryOveruse: return .simulateTunnelMemoryOveruse
case .simulateConnectionInterruption: return .simulateConnectionInterruption
case .getDataVolume: return .getDataVolume
case .startSnooze: return .startSnooze
case .cancelSnooze: return .cancelSnooze
}
}

Expand Down Expand Up @@ -199,7 +217,12 @@ public enum ExtensionMessage: RawRepresentable {
assertionFailure("could not encode routes: \(error)")
}
}

case .startSnooze(let interval):
encoder = { data in
withUnsafeBytes(of: UInt(interval).littleEndian) { buffer in
data.append(Data(buffer))
}
}
case .setSelectedServer(.none),
.setKeyValidity(.none),
.resetAllState,
Expand All @@ -214,7 +237,8 @@ public enum ExtensionMessage: RawRepresentable {
.simulateTunnelFatalError,
.simulateTunnelMemoryOveruse,
.simulateConnectionInterruption,
.getDataVolume: break
.getDataVolume,
.cancelSnooze: break

}

Expand Down
117 changes: 115 additions & 2 deletions Sources/NetworkProtection/PacketTunnelProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,12 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
}

if case .connected = connectionStatus {
self.notificationsPresenter.showConnectedNotification(serverLocation: lastSelectedServerInfo?.serverLocation)
self.notificationsPresenter.showConnectedNotification(
serverLocation: lastSelectedServerInfo?.serverLocation,
snoozeEnded: snoozeJustEnded
)

snoozeJustEnded = false
}

handleConnectionStatusChange(old: oldValue, new: connectionStatus)
Expand Down Expand Up @@ -410,6 +415,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
private let tunnelHealth: NetworkProtectionTunnelHealthStore
private let controllerErrorStore: NetworkProtectionTunnelErrorStore
private let knownFailureStore: NetworkProtectionKnownFailureStore
private let snoozeTimingStore: NetworkProtectionSnoozeTimingStore

// MARK: - Cancellables

Expand All @@ -428,6 +434,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
tunnelHealthStore: NetworkProtectionTunnelHealthStore,
controllerErrorStore: NetworkProtectionTunnelErrorStore,
knownFailureStore: NetworkProtectionKnownFailureStore = NetworkProtectionKnownFailureStore(),
snoozeTimingStore: NetworkProtectionSnoozeTimingStore,
keychainType: KeychainType,
tokenStore: NetworkProtectionTokenStore,
debugEvents: EventMapping<NetworkProtectionError>?,
Expand All @@ -446,6 +453,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
self.tunnelHealth = tunnelHealthStore
self.controllerErrorStore = controllerErrorStore
self.knownFailureStore = knownFailureStore
self.snoozeTimingStore = snoozeTimingStore
self.settings = settings
self.defaults = defaults
self.isSubscriptionEnabled = isSubscriptionEnabled
Expand Down Expand Up @@ -653,6 +661,9 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
let startupOptions = StartupOptions(options: options ?? [:])
os_log("Starting tunnel with options: %{public}s", log: .networkProtection, startupOptions.description)

// Reset snooze if the VPN is restarting.
self.snoozeTimingStore.reset()

do {
try load(options: startupOptions)
try loadVendorOptions(from: tunnelProviderProtocol)
Expand Down Expand Up @@ -810,6 +821,11 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
providerEvents.fire(.tunnelStopAttempt(.failure(error)))
}

if case .userInitiated = reason {
// If the user shut down the VPN deliberately, end snooze mode early.
self.snoozeTimingStore.reset()
}

if case .superceded = reason {
self.notificationsPresenter.showSupersededNotification()
}
Expand Down Expand Up @@ -1046,6 +1062,10 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
simulateConnectionInterruption(completionHandler: completionHandler)
case .getDataVolume:
getDataVolume(completionHandler: completionHandler)
case .startSnooze(let duration):
startSnooze(duration, completionHandler: completionHandler)
case .cancelSnooze:
cancelSnooze(completionHandler: completionHandler)
}
}

Expand Down Expand Up @@ -1363,6 +1383,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
case onDemand
case reconnected
case wake
case snoozeEnded
}

/// Called when the adapter reports that the tunnel was successfully started.
Expand Down Expand Up @@ -1641,7 +1662,6 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
@MainActor
public override func sleep() async {
os_log("Sleep", log: .networkProtectionSleepLog)

await stopMonitors()
}

Expand Down Expand Up @@ -1670,6 +1690,99 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
}
}
}

// MARK: - Snooze

private func startSnooze(_ duration: TimeInterval, completionHandler: ((Data?) -> Void)? = nil) {
Task {
await startSnooze(duration: duration)
completionHandler?(nil)
}
}

private func cancelSnooze(completionHandler: ((Data?) -> Void)? = nil) {
Task {
await cancelSnooze()
completionHandler?(nil)
}
}

private var snoozeTimerTask: Task<Never, Error>? {
willSet {
snoozeTimerTask?.cancel()
}
}

private var snoozeRequestProcessing: Bool = false
private var snoozeJustEnded: Bool = false

@MainActor
private func startSnooze(duration: TimeInterval) async {
if snoozeRequestProcessing {
os_log("Rejecting start snooze request due to existing request processing", log: .networkProtection)
return
}

snoozeRequestProcessing = true
os_log("Starting snooze mode with duration: %{public}d", log: .networkProtection, duration)

await stopMonitors()

self.adapter.snooze { [weak self] error in
guard let self else {
assertionFailure("Failed to get strong self")
return
}

if error == nil {
self.connectionStatus = .snoozing
self.snoozeTimingStore.activeTiming = .init(startDate: Date(), duration: duration)
self.notificationsPresenter.showSnoozingNotification(duration: duration)

snoozeTimerTask = Task.periodic(interval: .seconds(1)) { [weak self] in
guard let self else { return }

if self.snoozeTimingStore.hasExpired {
Task.detached {
os_log("Snooze mode timer expired, canceling snooze now...", log: .networkProtection)
await self.cancelSnooze()
}
}
}
} else {
self.snoozeTimingStore.reset()
}

self.snoozeRequestProcessing = false
}
}

private func cancelSnooze() async {
if snoozeRequestProcessing {
os_log("Rejecting cancel snooze request due to existing request processing", log: .networkProtection)
return
}

snoozeRequestProcessing = true
defer {
snoozeRequestProcessing = false
}

snoozeTimerTask?.cancel()
snoozeTimerTask = nil

guard await connectionStatus == .snoozing, snoozeTimingStore.activeTiming != nil else {
os_log("Failed to cancel snooze mode as it was not active", log: .networkProtection, type: .error)
return
}

os_log("Canceling snooze mode", log: .networkProtection)

snoozeJustEnded = true
try? await startTunnel(onDemand: false)
snoozeTimingStore.reset()
}

}

extension WireGuardAdapterError: LocalizedError, CustomDebugStringConvertible {
Expand Down
2 changes: 1 addition & 1 deletion Sources/NetworkProtection/StartupOptions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct StartupOptions {
///
case manualByMainApp

/// Started up manually from a Syste provided source: it can be the VPN menu, a CLI command
/// Started up manually from a system-provided source: it can be the VPN menu, a CLI command
/// or the list of VPNs in System Settings.
///
case manualByTheSystem
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ public class ConnectionServerInfoObserverThroughSession: ConnectionServerInfoObs

private func updateServerInfo(session: NETunnelProviderSession) async {
guard session.status == .connected else {
subject.send(NetworkProtectionStatusServerInfo.unknown)
return
}

Expand Down
3 changes: 3 additions & 0 deletions Sources/NetworkProtection/Status/ConnectionStatus.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public enum ConnectionStatus: Codable, Equatable {
case connected(connectedDate: Date)
case connecting
case reasserting
case snoozing

public static var `default`: ConnectionStatus = .disconnected

Expand All @@ -42,6 +43,8 @@ public enum ConnectionStatus: Codable, Equatable {
return "connecting"
case .reasserting:
return "reasserting"
case .snoozing:
return "snoozing"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public class ConnectionStatusObserverThroughSession: ConnectionStatusObserver {

// MARK: - Notifications
private let notificationCenter: NotificationCenter
private let platformSnoozeTimingStore: NetworkProtectionSnoozeTimingStore
private let platformNotificationCenter: NotificationCenter
private let platformDidWakeNotification: Notification.Name
private var cancellables = Set<AnyCancellable>()
Expand All @@ -49,11 +50,13 @@ public class ConnectionStatusObserverThroughSession: ConnectionStatusObserver {

public init(tunnelSessionProvider: TunnelSessionProvider,
notificationCenter: NotificationCenter = .default,
platformSnoozeTimingStore: NetworkProtectionSnoozeTimingStore,
platformNotificationCenter: NotificationCenter,
platformDidWakeNotification: Notification.Name,
log: OSLog = .networkProtection) {

self.notificationCenter = notificationCenter
self.platformSnoozeTimingStore = platformSnoozeTimingStore
self.platformNotificationCenter = platformNotificationCenter
self.platformDidWakeNotification = platformDidWakeNotification
self.tunnelSessionProvider = tunnelSessionProvider
Expand All @@ -74,8 +77,12 @@ public class ConnectionStatusObserverThroughSession: ConnectionStatusObserver {
self?.handleStatusChangeNotification(notification)
}.store(in: &cancellables)

notificationCenter.publisher(for: .VPNSnoozeRefreshed).sink { [weak self] notification in
self?.handleStatusRefreshNotification(notification)
}.store(in: &cancellables)

platformNotificationCenter.publisher(for: platformDidWakeNotification).sink { [weak self] notification in
self?.handleDidWake(notification)
self?.handleStatusRefreshNotification(notification)
}.store(in: &cancellables)
}

Expand All @@ -89,7 +96,7 @@ public class ConnectionStatusObserverThroughSession: ConnectionStatusObserver {

// MARK: - Handling Notifications

private func handleDidWake(_ notification: Notification) {
private func handleStatusRefreshNotification(_ notification: Notification) {
Task {
guard let session = await tunnelSessionProvider.activeSession() else {
return
Expand Down Expand Up @@ -118,7 +125,7 @@ public class ConnectionStatusObserverThroughSession: ConnectionStatusObserver {
private func connectedDate(from session: NETunnelProviderSession) -> Date {
// In theory when the connection has been established, the date should be set. But in a worst-case
// scenario where for some reason the date is missing, we're going to just use Date() as the connection
// has just started and it's a decent aproximation.
// has just started and it's a decent approximation.
session.connectedDate ?? Date()
}

Expand All @@ -128,8 +135,12 @@ public class ConnectionStatusObserverThroughSession: ConnectionStatusObserver {

switch internalStatus {
case .connected:
let connectedDate = connectedDate(from: session)
status = .connected(connectedDate: connectedDate)
if platformSnoozeTimingStore.activeTiming != nil {
status = .snoozing
} else {
let connectedDate = connectedDate(from: session)
status = .connected(connectedDate: connectedDate)
}
case .connecting:
status = .connecting
case .reasserting:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@ import Foundation
public protocol NetworkProtectionNotificationsPresenter {

/// Present a "connected" notification to the user.
func showConnectedNotification(serverLocation: String?)
func showConnectedNotification(serverLocation: String?, snoozeEnded: Bool)

/// Present a "reconnecting" notification to the user.
func showReconnectingNotification()

/// Present a "connection failure" notification to the user.
func showConnectionFailureNotification()

/// Present a "snoozing" notification to the user.
func showSnoozingNotification(duration: TimeInterval)

/// Present a "Superseded by another App" notification to the user.
func showSupersededNotification()

Expand All @@ -40,4 +43,5 @@ public protocol NetworkProtectionNotificationsPresenter {

/// Present a "expired subscription" notification to the user.
func showEntitlementNotification()

}
Loading

0 comments on commit c3ae186

Please sign in to comment.