From ba70fd8cc17db42b6a129178f04eb5dd15c3ba01 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 22 Aug 2024 15:22:09 +0000 Subject: [PATCH] Self-fix feedback --- Dockerfile | 1 + cmd/gluetun/main.go | 2 +- internal/command/split.go | 146 ++++++++++++++++++ internal/command/split_test.go | 111 +++++++++++++ .../configuration/settings/portforward.go | 19 ++- internal/portforward/interfaces.go | 6 + internal/portforward/loop.go | 8 +- internal/portforward/service/command.go | 53 +++++-- internal/portforward/service/command_test.go | 28 ++++ internal/portforward/service/interfaces.go | 6 + .../service/mocks_generate_test.go | 3 + internal/portforward/service/mocks_test.go | 82 ++++++++++ internal/portforward/service/service.go | 4 +- internal/portforward/service/settings.go | 6 +- internal/portforward/service/start.go | 16 +- 15 files changed, 452 insertions(+), 39 deletions(-) create mode 100644 internal/command/split.go create mode 100644 internal/command/split_test.go create mode 100644 internal/portforward/service/command_test.go create mode 100644 internal/portforward/service/mocks_generate_test.go create mode 100644 internal/portforward/service/mocks_test.go diff --git a/Dockerfile b/Dockerfile index e6a0bb071..08e83df1b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -125,6 +125,7 @@ ENV VPN_SERVICE_PROVIDER=pia \ VPN_PORT_FORWARDING_STATUS_FILE="/tmp/gluetun/forwarded_port" \ VPN_PORT_FORWARDING_USERNAME= \ VPN_PORT_FORWARDING_PASSWORD= \ + VPN_PORT_FORWARDING_UP_COMMAND= \ # # Cyberghost only: OPENVPN_CERT= \ OPENVPN_KEY= \ diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 8574f97eb..c5b471264 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -380,7 +380,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, portForwardLogger := logger.New(log.SetComponent("port forwarding")) portForwardLooper := portforward.NewLoop(allSettings.VPN.Provider.PortForwarding, - routingConf, httpClient, firewallConf, portForwardLogger, puid, pgid) + routingConf, httpClient, firewallConf, portForwardLogger, cmder, puid, pgid) portForwardRunError, err := portForwardLooper.Start(ctx) if err != nil { return fmt.Errorf("starting port forwarding loop: %w", err) diff --git a/internal/command/split.go b/internal/command/split.go new file mode 100644 index 000000000..1bbdf0c3f --- /dev/null +++ b/internal/command/split.go @@ -0,0 +1,146 @@ +package command + +import ( + "bytes" + "errors" + "fmt" + "strings" + "unicode/utf8" +) + +var ( + ErrCommandEmpty = errors.New("command is empty") + ErrSingleQuoteUnterminated = errors.New("unterminated single-quoted string") + ErrDoubleQuoteUnterminated = errors.New("unterminated double-quoted string") + ErrEscapeUnterminated = errors.New("unterminated backslash-escape") +) + +// Split splits a command string into a slice of arguments. +// This is especially important for commands such as: +// /bin/sh -c "echo hello" +// which should be split into: ["/bin/sh", "-c", "echo hello"] +// It supports backslash-escapes, single-quotes and double-quotes. +// It does not support: +// - the $" quoting style. +// - expansion (brace, shell or pathname). +func Split(command string) (words []string, err error) { + if command == "" { + return nil, fmt.Errorf("%w", ErrCommandEmpty) + } + + const bufferSize = 1024 + buffer := bytes.NewBuffer(make([]byte, bufferSize)) + + startIndex := 0 + + for startIndex < len(command) { + // skip any split characters at the start + character, runeSize := utf8.DecodeRuneInString(command[startIndex:]) + switch { + case strings.ContainsRune(" \n\t", character): + startIndex += runeSize + case character == '\\': + // Look ahead to eventually skip an escaped newline + if command[startIndex+runeSize:] == "" { + return nil, fmt.Errorf("%w: %q", ErrEscapeUnterminated, command) + } + character, runeSize := utf8.DecodeRuneInString(command[startIndex+runeSize:]) + if character == '\n' { + startIndex += runeSize + runeSize // backslash and newline + } + default: + var word string + buffer.Reset() + word, startIndex, err = splitWord(command, startIndex, buffer) + if err != nil { + return nil, fmt.Errorf("splitting word in %q: %w", command, err) + } + words = append(words, word) + } + } + return words, nil +} + +// WARNING: buffer must be cleared before calling this function. +func splitWord(input string, startIndex int, buffer *bytes.Buffer) ( + word string, newStartIndex int, err error) { + cursor := startIndex + for cursor < len(input) { + character, runeLength := utf8.DecodeRuneInString(input[cursor:]) + cursor += runeLength + if character == '"' || + character == '\'' || + character == '\\' || + character == ' ' || + character == '\n' || + character == '\t' { + buffer.WriteString(input[startIndex : cursor-runeLength]) + } + + switch { + case strings.ContainsRune(" \n\t", character): // spacing character + return buffer.String(), cursor, nil + case character == '"': + return handleDoubleQuoted(input, cursor, buffer) + case character == '\'': + return handleSingleQuoted(input, cursor, buffer) + case character == '\\': + return handleEscaped(input, cursor, buffer) + } + } + + buffer.WriteString(input[startIndex:]) + return buffer.String(), len(input), nil +} + +func handleDoubleQuoted(input string, startIndex int, buffer *bytes.Buffer) ( + word string, newStartIndex int, err error) { + cursor := startIndex + for cursor < len(input) { + nextCharacter, nextRuneLength := utf8.DecodeRuneInString(input[cursor:]) + cursor += nextRuneLength + switch nextCharacter { + case '"': // end of the double quoted string + buffer.WriteString(input[startIndex : cursor-nextRuneLength]) + return splitWord(input, cursor, buffer) + case '\\': // escaped character + escapedCharacter, escapedRuneLength := utf8.DecodeRuneInString(input[cursor:]) + cursor += escapedRuneLength + if !strings.ContainsRune("$`\"\n\\", escapedCharacter) { + break + } + buffer.WriteString(input[startIndex : cursor-nextRuneLength-escapedRuneLength]) + if escapedCharacter != '\n' { + // skip backslash entirely for the newline character + buffer.WriteRune(escapedCharacter) + } + startIndex = cursor + } + } + return "", 0, fmt.Errorf("%w", ErrDoubleQuoteUnterminated) +} + +func handleSingleQuoted(input string, startIndex int, buffer *bytes.Buffer) ( + word string, newStartIndex int, err error) { + closingQuoteIndex := strings.IndexRune(input[startIndex:], '\'') + if closingQuoteIndex == -1 { + return "", 0, fmt.Errorf("%w", ErrSingleQuoteUnterminated) + } + buffer.WriteString(input[startIndex : startIndex+closingQuoteIndex]) + const singleQuoteRuneLength = 1 + startIndex += closingQuoteIndex + singleQuoteRuneLength + return splitWord(input, startIndex, buffer) +} + +func handleEscaped(input string, startIndex int, buffer *bytes.Buffer) ( + word string, newStartIndex int, err error) { + if input[startIndex:] == "" { + return "", 0, fmt.Errorf("%w", ErrEscapeUnterminated) + } + character, runeLength := utf8.DecodeRuneInString(input[startIndex:]) + if character != '\n' { // backslash-escaped newline is ignored + buffer.WriteString(input[startIndex : startIndex+runeLength]) + } + startIndex += runeLength + return splitWord(input, startIndex, buffer) +} diff --git a/internal/command/split_test.go b/internal/command/split_test.go new file mode 100644 index 000000000..3c2108c33 --- /dev/null +++ b/internal/command/split_test.go @@ -0,0 +1,111 @@ +package command + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_Split(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + command string + words []string + errWrapped error + errMessage string + }{ + "empty": { + command: "", + errWrapped: ErrCommandEmpty, + errMessage: "command is empty", + }, + "concrete_sh_command": { + command: `/bin/sh -c "echo 123"`, + words: []string{"/bin/sh", "-c", "echo 123"}, + }, + "single_word": { + command: "word1", + words: []string{"word1"}, + }, + "two_words_single_space": { + command: "word1 word2", + words: []string{"word1", "word2"}, + }, + "two_words_multiple_space": { + command: "word1 word2", + words: []string{"word1", "word2"}, + }, + "two_words_no_expansion": { + command: "word1* word2?", + words: []string{"word1*", "word2?"}, + }, + "escaped_single quote": { + command: "ain\\'t good", + words: []string{"ain't", "good"}, + }, + "escaped_single_quote_all_single_quoted": { + command: "'ain'\\''t good'", + words: []string{"ain't good"}, + }, + "empty_single_quoted": { + command: "word1 '' word2", + words: []string{"word1", "", "word2"}, + }, + "escaped_newline": { + command: "word1\\\nword2", + words: []string{"word1word2"}, + }, + "quoted_newline": { + command: "text \"with\na\" quoted newline", + words: []string{"text", "with\na", "quoted", "newline"}, + }, + "quoted_escaped_newline": { + command: "\"word1\\d\\\\\\\" word2\\\nword3 word4\"", + words: []string{"word1\\d\\\" word2word3 word4"}, + }, + "escaped_separated_newline": { + command: "word1 \\\n word2", + words: []string{"word1", "word2"}, + }, + "double_quotes_no_spacing": { + command: "word1\"word2\"word3", + words: []string{"word1word2word3"}, + }, + "unterminated_single_quote": { + command: "'abc'\\''def", + errWrapped: ErrSingleQuoteUnterminated, + errMessage: `splitting word in "'abc'\\''def": unterminated single-quoted string`, + }, + "unterminated_double_quote": { + command: "\"abc'def", + errWrapped: ErrDoubleQuoteUnterminated, + errMessage: `splitting word in "\"abc'def": unterminated double-quoted string`, + }, + "unterminated_escape": { + command: "abc\\", + errWrapped: ErrEscapeUnterminated, + errMessage: `splitting word in "abc\\": unterminated backslash-escape`, + }, + "unterminated_escape_only": { + command: " \\", + errWrapped: ErrEscapeUnterminated, + errMessage: `unterminated backslash-escape: " \\"`, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + words, err := Split(testCase.command) + + assert.Equal(t, testCase.words, words) + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} diff --git a/internal/configuration/settings/portforward.go b/internal/configuration/settings/portforward.go index 87b33ced1..dfdcec256 100644 --- a/internal/configuration/settings/portforward.go +++ b/internal/configuration/settings/portforward.go @@ -29,11 +29,10 @@ type PortForwarding struct { // to write to a file. It cannot be nil for the // internal state Filepath *string `json:"status_file_path"` - // Command is the port forwarding status command - // to use. It can be the empty string to indicate not - // to run a command. It cannot be nil for the - // internal state - Command *string `json:"status_command"` + // UpCommand is the command to use when the port forwarding is up. + // It can be the empty string to indicate not to run a command. + // It cannot be nil in the internal state. + UpCommand *string `json:"up_command"` // ListeningPort is the port traffic would be redirected to from the // forwarded port. The redirection is disabled if it is set to 0, which // is its default as well. @@ -89,7 +88,7 @@ func (p *PortForwarding) Copy() (copied PortForwarding) { Enabled: gosettings.CopyPointer(p.Enabled), Provider: gosettings.CopyPointer(p.Provider), Filepath: gosettings.CopyPointer(p.Filepath), - Command: gosettings.CopyPointer(p.Command), + UpCommand: gosettings.CopyPointer(p.UpCommand), ListeningPort: gosettings.CopyPointer(p.ListeningPort), Username: p.Username, Password: p.Password, @@ -100,7 +99,7 @@ func (p *PortForwarding) OverrideWith(other PortForwarding) { p.Enabled = gosettings.OverrideWithPointer(p.Enabled, other.Enabled) p.Provider = gosettings.OverrideWithPointer(p.Provider, other.Provider) p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath) - p.Command = gosettings.OverrideWithPointer(p.Command, other.Command) + p.UpCommand = gosettings.OverrideWithPointer(p.UpCommand, other.UpCommand) p.ListeningPort = gosettings.OverrideWithPointer(p.ListeningPort, other.ListeningPort) p.Username = gosettings.OverrideWithComparable(p.Username, other.Username) p.Password = gosettings.OverrideWithComparable(p.Password, other.Password) @@ -110,7 +109,7 @@ func (p *PortForwarding) setDefaults() { p.Enabled = gosettings.DefaultPointer(p.Enabled, false) p.Provider = gosettings.DefaultPointer(p.Provider, "") p.Filepath = gosettings.DefaultPointer(p.Filepath, "/tmp/gluetun/forwarded_port") - p.Command = gosettings.DefaultPointer(p.Command, "") + p.UpCommand = gosettings.DefaultPointer(p.UpCommand, "") p.ListeningPort = gosettings.DefaultPointer(p.ListeningPort, 0) } @@ -143,7 +142,7 @@ func (p PortForwarding) toLinesNode() (node *gotree.Node) { } node.Appendf("Forwarded port file path: %s", filepath) - command := *p.Command + command := *p.UpCommand if command != "" { node.Appendf("Forwarded port command: %s", command) } @@ -176,7 +175,7 @@ func (p *PortForwarding) read(r *reader.Reader) (err error) { "PRIVATE_INTERNET_ACCESS_VPN_PORT_FORWARDING_STATUS_FILE", )) - p.Command = r.Get("VPN_PORT_FORWARDING_UP_COMMAND", + p.UpCommand = r.Get("VPN_PORT_FORWARDING_UP_COMMAND", reader.ForceLowercase(false)) p.ListeningPort, err = r.Uint16Ptr("VPN_PORT_FORWARDING_LISTENING_PORT") diff --git a/internal/portforward/interfaces.go b/internal/portforward/interfaces.go index c3c610ae7..fb442d5eb 100644 --- a/internal/portforward/interfaces.go +++ b/internal/portforward/interfaces.go @@ -3,6 +3,7 @@ package portforward import ( "context" "net/netip" + "os/exec" ) type Service interface { @@ -29,3 +30,8 @@ type Logger interface { Warn(s string) Error(s string) } + +type Cmder interface { + Start(cmd *exec.Cmd) (stdoutLines, stderrLines <-chan string, + waitError <-chan error, startErr error) +} diff --git a/internal/portforward/loop.go b/internal/portforward/loop.go index 1fe9fa08a..cc9e0a157 100644 --- a/internal/portforward/loop.go +++ b/internal/portforward/loop.go @@ -20,6 +20,7 @@ type Loop struct { client *http.Client portAllower PortAllower logger Logger + cmder Cmder // Fixed parameters uid, gid int // Internal channels and locks @@ -34,7 +35,7 @@ type Loop struct { func NewLoop(settings settings.PortForwarding, routing Routing, client *http.Client, portAllower PortAllower, - logger Logger, uid, gid int, + logger Logger, cmder Cmder, uid, gid int, ) *Loop { return &Loop{ settings: Settings{ @@ -42,7 +43,7 @@ func NewLoop(settings settings.PortForwarding, routing Routing, Service: service.Settings{ Enabled: settings.Enabled, Filepath: *settings.Filepath, - Command: *settings.Command, + UpCommand: *settings.UpCommand, ListeningPort: *settings.ListeningPort, }, }, @@ -50,6 +51,7 @@ func NewLoop(settings settings.PortForwarding, routing Routing, client: client, portAllower: portAllower, logger: logger, + cmder: cmder, uid: uid, gid: gid, } @@ -116,7 +118,7 @@ func (l *Loop) run(runCtx context.Context, runDone chan<- struct{}, *serviceSettings.Enabled = *serviceSettings.Enabled && *l.settings.VPNIsUp l.service = service.New(serviceSettings, l.routing, l.client, - l.portAllower, l.logger, l.uid, l.gid) + l.portAllower, l.logger, l.cmder, l.uid, l.gid) var err error serviceRunError, err = l.service.Start(runCtx) diff --git a/internal/portforward/service/command.go b/internal/portforward/service/command.go index f721a4d2b..30a6dcf4a 100644 --- a/internal/portforward/service/command.go +++ b/internal/portforward/service/command.go @@ -3,30 +3,57 @@ package service import ( "context" "fmt" - "os" "os/exec" "strings" + + "github.com/qdm12/gluetun/internal/command" ) -func (s *Service) runUpCommand(ctx context.Context, ports []uint16) (err error) { - // run command replacing {{PORTS}} with the ports (space separated) +func runUpCommand(ctx context.Context, cmder Cmder, logger Logger, + commandTemplate string, ports []uint16, +) (err error) { portStrings := make([]string, len(ports)) for i, port := range ports { portStrings[i] = fmt.Sprint(int(port)) } portsString := strings.Join(portStrings, ",") + commandString := strings.ReplaceAll(commandTemplate, "{{PORTS}}", portsString) + args, err := command.Split(commandString) + if err != nil { + return fmt.Errorf("parsing command: %w", err) + } - rawCommand := strings.ReplaceAll(s.settings.Command, "{{PORTS}}", portsString) - s.logger.Info("running port forward command " + rawCommand) - command := strings.Split(rawCommand, " ") - // it is a user input and we trust it so we can ignore the gosec warning - cmd := exec.CommandContext(ctx, command[0], command[1:]...) // #nosec G204 - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - err = cmd.Run() + cmd := exec.CommandContext(ctx, args[0], args[1:]...) // #nosec G204 + stdout, stderr, waitError, err := cmder.Start(cmd) if err != nil { - return fmt.Errorf("running command: %w", err) + return err } - return nil + streamCtx, streamCancel := context.WithCancel(context.Background()) + streamDone := make(chan struct{}) + go streamLines(streamCtx, streamDone, logger, stdout, stderr) + + err = <-waitError + streamCancel() + <-streamDone + return err +} + +func streamLines(ctx context.Context, done chan<- struct{}, + logger Logger, stdout, stderr <-chan string, +) { + defer close(done) + + var line string + + for { + select { + case <-ctx.Done(): + return + case line = <-stdout: + logger.Info(line) + case line = <-stderr: + logger.Error(line) + } + } } diff --git a/internal/portforward/service/command_test.go b/internal/portforward/service/command_test.go new file mode 100644 index 000000000..59b471aa2 --- /dev/null +++ b/internal/portforward/service/command_test.go @@ -0,0 +1,28 @@ +//go:build linux + +package service + +import ( + "context" + "testing" + + gomock "github.com/golang/mock/gomock" + "github.com/qdm12/gluetun/internal/command" + "github.com/stretchr/testify/require" +) + +func Test_Service_runUpCommand(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + ctx := context.Background() + cmder := command.New() + const commandTemplate = `/bin/sh -c "echo {{PORTS}}"` + ports := []uint16{1234, 5678} + logger := NewMockLogger(ctrl) + logger.EXPECT().Info("1234,5678") + + err := runUpCommand(ctx, cmder, logger, commandTemplate, ports) + + require.NoError(t, err) +} diff --git a/internal/portforward/service/interfaces.go b/internal/portforward/service/interfaces.go index 9a1f7c040..01876be82 100644 --- a/internal/portforward/service/interfaces.go +++ b/internal/portforward/service/interfaces.go @@ -3,6 +3,7 @@ package service import ( "context" "net/netip" + "os/exec" "github.com/qdm12/gluetun/internal/provider/utils" ) @@ -32,3 +33,8 @@ type PortForwarder interface { ports []uint16, err error) KeepPortForward(ctx context.Context, objects utils.PortForwardObjects) (err error) } + +type Cmder interface { + Start(cmd *exec.Cmd) (stdoutLines, stderrLines <-chan string, + waitError <-chan error, startErr error) +} diff --git a/internal/portforward/service/mocks_generate_test.go b/internal/portforward/service/mocks_generate_test.go new file mode 100644 index 000000000..776a506df --- /dev/null +++ b/internal/portforward/service/mocks_generate_test.go @@ -0,0 +1,3 @@ +package service + +//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Logger diff --git a/internal/portforward/service/mocks_test.go b/internal/portforward/service/mocks_test.go new file mode 100644 index 000000000..69fa3a0ce --- /dev/null +++ b/internal/portforward/service/mocks_test.go @@ -0,0 +1,82 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/gluetun/internal/portforward/service (interfaces: Logger) + +// Package service is a generated GoMock package. +package service + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockLogger is a mock of Logger interface. +type MockLogger struct { + ctrl *gomock.Controller + recorder *MockLoggerMockRecorder +} + +// MockLoggerMockRecorder is the mock recorder for MockLogger. +type MockLoggerMockRecorder struct { + mock *MockLogger +} + +// NewMockLogger creates a new mock instance. +func NewMockLogger(ctrl *gomock.Controller) *MockLogger { + mock := &MockLogger{ctrl: ctrl} + mock.recorder = &MockLoggerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLogger) EXPECT() *MockLoggerMockRecorder { + return m.recorder +} + +// Debug mocks base method. +func (m *MockLogger) Debug(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Debug", arg0) +} + +// Debug indicates an expected call of Debug. +func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0) +} + +// Error mocks base method. +func (m *MockLogger) Error(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Error", arg0) +} + +// Error indicates an expected call of Error. +func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0) +} + +// Info mocks base method. +func (m *MockLogger) Info(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Info", arg0) +} + +// Info indicates an expected call of Info. +func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0) +} + +// Warn mocks base method. +func (m *MockLogger) Warn(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Warn", arg0) +} + +// Warn indicates an expected call of Warn. +func (mr *MockLoggerMockRecorder) Warn(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warn", reflect.TypeOf((*MockLogger)(nil).Warn), arg0) +} diff --git a/internal/portforward/service/service.go b/internal/portforward/service/service.go index 077e027e6..579b37397 100644 --- a/internal/portforward/service/service.go +++ b/internal/portforward/service/service.go @@ -19,6 +19,7 @@ type Service struct { client *http.Client portAllower PortAllower logger Logger + cmder Cmder // Internal channels and locks startStopMutex sync.Mutex keepPortCancel context.CancelFunc @@ -26,7 +27,7 @@ type Service struct { } func New(settings Settings, routing Routing, client *http.Client, - portAllower PortAllower, logger Logger, puid, pgid int, + portAllower PortAllower, logger Logger, cmder Cmder, puid, pgid int, ) *Service { return &Service{ // Fixed parameters @@ -38,6 +39,7 @@ func New(settings Settings, routing Routing, client *http.Client, client: client, portAllower: portAllower, logger: logger, + cmder: cmder, } } diff --git a/internal/portforward/service/settings.go b/internal/portforward/service/settings.go index 7af567e43..b44d790c6 100644 --- a/internal/portforward/service/settings.go +++ b/internal/portforward/service/settings.go @@ -12,7 +12,7 @@ type Settings struct { Enabled *bool PortForwarder PortForwarder Filepath string - Command string + UpCommand string Interface string // needed for PIA, PrivateVPN and ProtonVPN, tun0 for example ServerName string // needed for PIA CanPortForward bool // needed for PIA @@ -25,7 +25,7 @@ func (s Settings) Copy() (copied Settings) { copied.Enabled = gosettings.CopyPointer(s.Enabled) copied.PortForwarder = s.PortForwarder copied.Filepath = s.Filepath - copied.Command = s.Command + copied.UpCommand = s.UpCommand copied.Interface = s.Interface copied.ServerName = s.ServerName copied.CanPortForward = s.CanPortForward @@ -39,7 +39,7 @@ func (s *Settings) OverrideWith(update Settings) { s.Enabled = gosettings.OverrideWithPointer(s.Enabled, update.Enabled) s.PortForwarder = gosettings.OverrideWithComparable(s.PortForwarder, update.PortForwarder) s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath) - s.Command = gosettings.OverrideWithComparable(s.Command, update.Command) + s.UpCommand = gosettings.OverrideWithComparable(s.UpCommand, update.UpCommand) s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface) s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName) s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward) diff --git a/internal/portforward/service/start.go b/internal/portforward/service/start.go index 5b16e5ca7..c7fb9bbf9 100644 --- a/internal/portforward/service/start.go +++ b/internal/portforward/service/start.go @@ -69,18 +69,18 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error) return nil, fmt.Errorf("writing port file: %w", err) } - if s.settings.Command != "" { - err = s.runUpCommand(ctx, ports) - if err != nil { - _ = s.cleanup() - return nil, fmt.Errorf("running port forward command: %w", err) - } - } - s.portMutex.Lock() s.ports = ports s.portMutex.Unlock() + if s.settings.UpCommand != "" { + err = runUpCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports) + if err != nil { + err = fmt.Errorf("running up command: %w", err) + s.logger.Error(err.Error()) + } + } + keepPortCtx, keepPortCancel := context.WithCancel(context.Background()) s.keepPortCancel = keepPortCancel runErrorCh := make(chan error)