diff --git a/go.sum b/go.sum index ab793c61..3153ad53 100644 --- a/go.sum +++ b/go.sum @@ -15,8 +15,6 @@ github.com/ilyakaznacheev/cleanenv v1.5.0 h1:0VNZXggJE2OYdXE87bfSSwGxeiGt9moSR2l github.com/ilyakaznacheev/cleanenv v1.5.0/go.mod h1:a5aDzaJrLCQZsazHol1w8InnDcOX0OColm64SlIi6gk= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= -github.com/open-amt-cloud-toolkit/go-wsman-messages/v2 v2.8.0 h1:4HlR1Mr1KELUPii2IXJ/WDS/j8bVA8aCjc30nJFF1rY= -github.com/open-amt-cloud-toolkit/go-wsman-messages/v2 v2.8.0/go.mod h1:Z/zRJrraqGMxVTAqVRKE2QgeySouZP2vwkCy9u8UYb0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= diff --git a/internal/lm/engine.go b/internal/lm/engine.go index e405c7d7..5dade0cd 100644 --- a/internal/lm/engine.go +++ b/internal/lm/engine.go @@ -8,6 +8,7 @@ import ( "bytes" "encoding/binary" "rpc/pkg/pthi" + "sync" "time" "github.com/open-amt-cloud-toolkit/go-wsman-messages/v2/pkg/apf" @@ -22,7 +23,7 @@ type LMEConnection struct { retries int } -func NewLMEConnection(data chan []byte, errors chan error, status chan bool) *LMEConnection { +func NewLMEConnection(data chan []byte, errors chan error, status chan bool, wg *sync.WaitGroup) *LMEConnection { lme := &LMEConnection{ ourChannel: 1, } @@ -32,6 +33,7 @@ func NewLMEConnection(data chan []byte, errors chan error, status chan bool) *LM ErrorBuffer: errors, Tempdata: []byte{}, Status: status, + WaitGroup: wg, } return lme @@ -65,7 +67,7 @@ func (lme *LMEConnection) Connect() error { } else { lme.ourChannel = channel } - + lme.Session.WaitGroup.Add(1) bin_buf := apf.ChannelOpen(lme.ourChannel) err := lme.Command.Send(bin_buf.Bytes(), uint32(bin_buf.Len())) if err != nil { @@ -133,15 +135,6 @@ func (lme *LMEConnection) Listen() { lme.Session.DataBuffer <- lme.Session.Tempdata lme.Session.Tempdata = []byte{} var bin_buf bytes.Buffer - // var windowAdjust apf.APF_CHANNEL_WINDOW_ADJUST_MESSAGE - // if lme.Session.RXWindow > 1024 { // TODO: Check this - // windowAdjust = apf.ChannelWindowAdjust(lme.Session.RecipientChannel, lme.Session.RXWindow) - // lme.Session.RXWindow = 0 - // binary.Write(&bin_buf, binary.BigEndian, windowAdjust.MessageType) - // binary.Write(&bin_buf, binary.BigEndian, windowAdjust.RecipientChannel) - // lme.Command.Call(bin_buf.Bytes(), uint32(bin_buf.Len())) - // } - channelData := apf.ChannelClose(lme.Session.SenderChannel) binary.Write(&bin_buf, binary.BigEndian, channelData.MessageType) binary.Write(&bin_buf, binary.BigEndian, channelData.RecipientChannel) @@ -153,7 +146,7 @@ func (lme *LMEConnection) Listen() { result2, bytesRead, err2 := lme.Command.Receive() if bytesRead == 0 || err2 != nil { log.Trace("NO MORE DATA TO READ") - break + // break } else { result := apf.Process(result2, lme.Session) if result.Len() != 0 { @@ -167,7 +160,7 @@ func (lme *LMEConnection) Listen() { } } -// Close closes the LMS socket connection +// Close closes the LME connection func (lme *LMEConnection) Close() error { log.Debug("closing connection to lme") lme.Command.Close() diff --git a/internal/local/amt/localTransport.go b/internal/local/amt/localTransport.go index af846511..a0da25df 100644 --- a/internal/local/amt/localTransport.go +++ b/internal/local/amt/localTransport.go @@ -12,27 +12,30 @@ import ( "io" "net/http" "rpc/internal/lm" + "sync" "github.com/sirupsen/logrus" ) -// LocalTransport - Your custom net.Conn implementation type LocalTransport struct { - local lm.LocalMananger - data chan []byte - errors chan error - status chan bool + local lm.LocalMananger + data chan []byte + errors chan error + status chan bool + waitGroup *sync.WaitGroup } func NewLocalTransport() *LocalTransport { lmDataChannel := make(chan []byte) lmErrorChannel := make(chan error) lmStatus := make(chan bool) + waiter := &sync.WaitGroup{} lm := &LocalTransport{ - local: lm.NewLMEConnection(lmDataChannel, lmErrorChannel, lmStatus), - data: lmDataChannel, - errors: lmErrorChannel, - status: lmStatus, + local: lm.NewLMEConnection(lmDataChannel, lmErrorChannel, lmStatus, waiter), + data: lmDataChannel, + errors: lmErrorChannel, + status: lmStatus, + waitGroup: waiter, } // defer lm.local.Close() // defer close(lmDataChannel) @@ -49,20 +52,18 @@ func NewLocalTransport() *LocalTransport { // Custom dialer function func (l *LocalTransport) RoundTrip(r *http.Request) (*http.Response, error) { - //Something comes here...Maybe - go l.local.Listen() - // send channel open err := l.local.Connect() + //Something comes here...Maybe + go l.local.Listen() if err != nil { logrus.Error(err) return nil, err } // wait for channel open confirmation - <-l.status + l.waitGroup.Wait() logrus.Trace("Channel open confirmation received") - // Serialize the HTTP request to raw form rawRequest, err := serializeHTTPRequest(r) if err != nil { @@ -71,22 +72,30 @@ func (l *LocalTransport) RoundTrip(r *http.Request) (*http.Response, error) { } var responseReader *bufio.Reader - // send our data to LMX - err = l.local.Send(rawRequest) + + err = l.local.Send([]byte(rawRequest)) if err != nil { logrus.Error(err) return nil, err } - for dataFromLM := range l.data { - if len(dataFromLM) > 0 { - logrus.Debug("received data from LME") - logrus.Trace(string(dataFromLM)) - - // /<-l.status - responseReader = bufio.NewReader(bytes.NewReader(dataFromLM)) - break +Loop: + for { + select { + case dataFromLM := <-l.data: + if len(dataFromLM) > 0 { + logrus.Debug("received data from LME") + logrus.Trace(string(dataFromLM)) + responseReader = bufio.NewReader(bytes.NewReader(dataFromLM)) + break Loop + } + case errFromLMS := <-l.errors: + if errFromLMS != nil { + logrus.Error("error from LMS") + break Loop + } } + } response, err := http.ReadResponse(responseReader, r) @@ -101,6 +110,8 @@ func (l *LocalTransport) RoundTrip(r *http.Request) (*http.Response, error) { func serializeHTTPRequest(r *http.Request) ([]byte, error) { var reqBuffer bytes.Buffer + r.Header.Set("Transfer-Encoding", "chunked") + // Write request line reqLine := fmt.Sprintf("%s %s %s\r\n", r.Method, r.URL.RequestURI(), r.Proto) reqBuffer.WriteString(reqLine) @@ -115,8 +126,12 @@ func serializeHTTPRequest(r *http.Request) ([]byte, error) { if err != nil { return nil, err } + length := fmt.Sprintf("%x", len(bodyBytes)) + bodyBytes = append([]byte(length+"\r\n"), bodyBytes...) + bodyBytes = append(bodyBytes, []byte("\r\n0\r\n\r\n")...) // Important: Replace the body so it can be read again later if needed r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + reqBuffer.Write(bodyBytes) } diff --git a/internal/local/amt/wsman.go b/internal/local/amt/wsman.go index cbd73b5b..f05a23cc 100644 --- a/internal/local/amt/wsman.go +++ b/internal/local/amt/wsman.go @@ -7,6 +7,7 @@ package amt import ( "encoding/base64" + "net" "rpc/pkg/utils" "github.com/open-amt-cloud-toolkit/go-wsman-messages/v2/pkg/wsman" @@ -29,6 +30,7 @@ import ( "github.com/open-amt-cloud-toolkit/go-wsman-messages/v2/pkg/wsman/ips/hostbasedsetup" "github.com/open-amt-cloud-toolkit/go-wsman-messages/v2/pkg/wsman/ips/ieee8021x" "github.com/open-amt-cloud-toolkit/go-wsman-messages/v2/pkg/wsman/ips/optin" + "github.com/sirupsen/logrus" ) type WSMANer interface { @@ -94,7 +96,6 @@ func NewGoWSMANMessages(lmsAddress string) *GoWSMANMessages { } func (g *GoWSMANMessages) SetupWsmanClient(username string, password string, logAMTMessages bool) { - clientParams := client.Parameters{ Target: g.target, Username: username, @@ -103,6 +104,19 @@ func (g *GoWSMANMessages) SetupWsmanClient(username string, password string, log UseTLS: false, LogAMTMessages: logAMTMessages, } + logrus.Info("Attempting to connect to LMS...") + port := utils.LMSPort + if clientParams.UseTLS { + port = client.TLSPort + } + con, err := net.Dial("tcp4", utils.LMSAddress+":"+port) + if err != nil { + logrus.Info("Failed to connect to LMS, using local transport instead.") + clientParams.Transport = NewLocalTransport() + } else { + logrus.Info("Successfully connected to LMS.") + con.Close() + } g.wsmanMessages = wsman.NewMessages(clientParams) } diff --git a/internal/rps/executor.go b/internal/rps/executor.go index 03ad84c1..44e038d9 100644 --- a/internal/rps/executor.go +++ b/internal/rps/executor.go @@ -10,6 +10,7 @@ import ( "rpc/internal/flags" "rpc/internal/lm" "rpc/pkg/utils" + "sync" "syscall" log "github.com/sirupsen/logrus" @@ -23,6 +24,7 @@ type Executor struct { data chan []byte errors chan error status chan bool + waitGroup *sync.WaitGroup } func NewExecutor(flags flags.Flags) (Executor, error) { @@ -35,6 +37,7 @@ func NewExecutor(flags flags.Flags) (Executor, error) { localManagement: lm.NewLMSConnection(utils.LMSAddress, utils.LMSPort, lmDataChannel, lmErrorChannel), data: lmDataChannel, errors: lmErrorChannel, + waitGroup: &sync.WaitGroup{}, } // TEST CONNECTION TO SEE IF LMS EXISTS @@ -44,7 +47,7 @@ func NewExecutor(flags flags.Flags) (Executor, error) { // client.localManagement.Close() log.Trace("LMS not running. Using LME Connection\n") client.status = make(chan bool) - client.localManagement = lm.NewLMEConnection(lmDataChannel, lmErrorChannel, client.status) + client.localManagement = lm.NewLMEConnection(lmDataChannel, lmErrorChannel, client.status, client.waitGroup) client.isLME = true client.localManagement.Initialize() } else { @@ -131,7 +134,7 @@ func (e Executor) HandleDataFromRPS(dataFromServer []byte) bool { } if e.isLME { // wait for channel open confirmation - <-e.status + e.waitGroup.Wait() log.Trace("Channel open confirmation received") } else { //with LMS we open/close websocket on every request, so setup close for when we're done handling LMS data @@ -150,7 +153,7 @@ func (e Executor) HandleDataFromRPS(dataFromServer []byte) bool { case dataFromLM := <-e.data: e.HandleDataFromLM(dataFromLM) if e.isLME { - <-e.status + e.waitGroup.Wait() } return false case errFromLMS := <-e.errors: