diff --git a/html.go b/html.go index f2d24a9..c04d601 100644 --- a/html.go +++ b/html.go @@ -43,8 +43,7 @@ func appendAttrs(b []byte, attrs []template.HTMLAttr) []byte { } func WriteHtmlTag(w io.Writer, jid jid.Jid, htmlTag, typeAttr, valueAttr string, attrs []template.HTMLAttr) (err error) { - b := make([]byte, 0, 64) - b = jid.AppendStartTagAttr(b, htmlTag) + b := jid.AppendStartTagAttr(nil, htmlTag) if typeAttr != "" { b = append(b, ` type=`...) b = strconv.AppendQuote(b, typeAttr) @@ -66,13 +65,12 @@ func WriteHtmlInput(w io.Writer, jid jid.Jid, typeAttr, valueAttr string, attrs func WriteHtmlInner(w io.Writer, jid jid.Jid, htmlTag, typeAttr string, innerHtml template.HTML, attrs ...template.HTMLAttr) (err error) { if err = WriteHtmlTag(w, jid, htmlTag, typeAttr, "", attrs); err == nil { if innerHtml != "" || needClosingTag(htmlTag) { - if _, err = w.Write([]byte(innerHtml)); err == nil { - var b []byte - b = append(b, "') - _, err = w.Write(b) - } + var b []byte + b = append(b, innerHtml...) + b = append(b, "') + _, err = w.Write(b) } } return diff --git a/request.go b/request.go index 9be66e6..21360ab 100644 --- a/request.go +++ b/request.go @@ -417,12 +417,12 @@ func (rq *Request) Done() (ch <-chan struct{}) { } // process is the main message processing loop. Will unsubscribe broadcastMsgCh and close outboundMsgCh on exit. -func (rq *Request) process(broadcastMsgCh chan Message, incomingMsgCh <-chan wsMsg, outboundCh chan<- string) { +func (rq *Request) process(broadcastMsgCh chan Message, incomingMsgCh <-chan wsMsg, outboundMsgCh chan<- wsMsg) { jawsDoneCh := rq.Jaws.Done() ctxDoneCh := rq.Done() eventDoneCh := make(chan struct{}) - eventCallCh := make(chan eventFnCall, cap(outboundCh)) - go rq.eventCaller(eventCallCh, outboundCh, eventDoneCh) + eventCallCh := make(chan eventFnCall, cap(outboundMsgCh)) + go rq.eventCaller(eventCallCh, outboundMsgCh, eventDoneCh) defer func() { rq.Jaws.unsubscribe(broadcastMsgCh) @@ -433,7 +433,7 @@ func (rq *Request) process(broadcastMsgCh chan Message, incomingMsgCh <-chan wsM case <-eventCallCh: case <-incomingMsgCh: case <-eventDoneCh: - close(outboundCh) + close(outboundMsgCh) if x := recover(); x != nil { var err error var ok bool @@ -455,7 +455,7 @@ func (rq *Request) process(broadcastMsgCh chan Message, incomingMsgCh <-chan wsM var ok bool if len(wsQueue) > 0 { - wsQueue = rq.sendQueue(outboundCh, wsQueue) + wsQueue = rq.sendQueue(outboundMsgCh, wsQueue) } // Empty the dirty tags list and call JawsUpdate() @@ -475,7 +475,7 @@ func (rq *Request) process(broadcastMsgCh chan Message, incomingMsgCh <-chan wsM rq.mu.RUnlock() if len(wsQueue) > 0 { - wsQueue = rq.sendQueue(outboundCh, wsQueue) + wsQueue = rq.sendQueue(outboundMsgCh, wsQueue) } select { @@ -629,21 +629,19 @@ func (rq *Request) queueEvent(eventCallCh chan eventFnCall, call eventFnCall) { } } -func (rq *Request) wsSend(outboundCh chan<- string, s string) { +func (rq *Request) wsSend(outboundMsgCh chan<- wsMsg, msg wsMsg) { select { case <-rq.Done(): - case outboundCh <- s: + case outboundMsgCh <- msg: default: - panic(fmt.Errorf("jaws: %v: outbound message channel is full (%d) sending %s", rq, len(outboundCh), s)) + panic(fmt.Errorf("jaws: %v: outbound message channel is full (%d) sending %s", rq, len(outboundMsgCh), msg)) } } -func (rq *Request) sendQueue(outboundCh chan<- string, wsQueue []wsMsg) []wsMsg { - var sb strings.Builder +func (rq *Request) sendQueue(outboundMsgCh chan<- wsMsg, wsQueue []wsMsg) []wsMsg { for _, msg := range wsQueue { - sb.WriteString(msg.Format()) + rq.wsSend(outboundMsgCh, msg) } - rq.wsSend(outboundCh, sb.String()) return wsQueue[:0] } @@ -700,14 +698,14 @@ func (rq *Request) makeUpdateList() (todo []*Element) { } // eventCaller calls event functions -func (rq *Request) eventCaller(eventCallCh <-chan eventFnCall, outboundCh chan<- string, eventDoneCh chan<- struct{}) { +func (rq *Request) eventCaller(eventCallCh <-chan eventFnCall, outboundMsgCh chan<- wsMsg, eventDoneCh chan<- struct{}) { defer close(eventDoneCh) for call := range eventCallCh { if err := rq.callAllEventHandlers(call.jid, call.wht, call.data); err != nil { var m wsMsg m.FillAlert(err) select { - case outboundCh <- m.Format(): + case outboundMsgCh <- m: default: _ = rq.Jaws.Log(fmt.Errorf("jaws: outboundMsgCh full sending event error '%s'", err.Error())) } diff --git a/request_test.go b/request_test.go index dd17725..5e6ddba 100644 --- a/request_test.go +++ b/request_test.go @@ -18,10 +18,10 @@ import ( const testTimeout = time.Second * 3 -func fillWsCh(ch chan string) { +func fillWsCh(ch chan wsMsg) { for { select { - case ch <- "": + case ch <- wsMsg{}: default: return } @@ -72,9 +72,7 @@ func TestRequest_SendArrivesOk(t *testing.T) { select { case <-time.NewTimer(testTimeout).C: is.Fail() - case msgstr := <-rq.outCh: - msg, ok := wsParse([]byte(msgstr)) - is.True(ok) + case msg := <-rq.outCh: elem := rq.getElementByJid(jid) is.True(elem != nil) is.Equal(msg, wsMsg{Jid: elem.jid, Data: "bar", What: what.Inner}) @@ -214,7 +212,7 @@ func TestRequest_Trigger(t *testing.T) { case <-th.C: th.Timeout() case msg := <-rq.outCh: - th.Equal(msg, (&wsMsg{ + th.Equal(msg.Format(), (&wsMsg{ Data: "danger\nomg", Jid: jid.Jid(0), What: what.Alert, @@ -393,7 +391,8 @@ func TestRequest_Alert(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq1.outCh: + case msg := <-rq1.outCh: + s := msg.Format() if s != "Alert\t\t\"info\\n\\nnot\\tescaped\"\n" { t.Errorf("%q", s) } @@ -416,7 +415,8 @@ func TestRequest_Redirect(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq1.outCh: + case msg := <-rq1.outCh: + s := msg.Format() if s != "Redirect\t\t\"some-url\"\n" { t.Errorf("%q", s) } @@ -437,7 +437,8 @@ func TestRequest_AlertError(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "Alert\t\t\"danger\\n<html>\\nshould-be-escaped\"\n" { t.Errorf("%q", s) } @@ -478,19 +479,43 @@ func TestRequest_DeleteByTag(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq1.outCh: - if s != "Delete\tJid.1\t\"\"\nDelete\tJid.3\t\"\"\n" { + case msg := <-rq1.outCh: + s := msg.Format() + if s != "Delete\tJid.1\t\"\"\n" { t.Errorf("%q", s) } } + select { case <-th.C: th.Timeout() - case s := <-rq2.outCh: - if s != "Delete\tJid.4\t\"\"\nDelete\tJid.6\t\"\"\n" { + case msg := <-rq1.outCh: + s := msg.Format() + if s != "Delete\tJid.3\t\"\"\n" { t.Errorf("%q", s) } } + + select { + case <-th.C: + th.Timeout() + case msg := <-rq2.outCh: + s := msg.Format() + if s != "Delete\tJid.4\t\"\"\n" { + t.Errorf("%q", s) + } + } + + select { + case <-th.C: + th.Timeout() + case msg := <-rq2.outCh: + s := msg.Format() + if s != "Delete\tJid.6\t\"\"\n" { + t.Errorf("%q", s) + } + } + } func TestRequest_HtmlIdBroadcast(t *testing.T) { @@ -508,7 +533,8 @@ func TestRequest_HtmlIdBroadcast(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq1.outCh: + case msg := <-rq1.outCh: + s := msg.Format() if s != "Inner\tfooId\t\"inner\"\n" { t.Errorf("%q", s) } @@ -516,7 +542,8 @@ func TestRequest_HtmlIdBroadcast(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq2.outCh: + case msg := <-rq2.outCh: + s := msg.Format() if s != "Inner\tfooId\t\"inner\"\n" { t.Errorf("%q", s) } diff --git a/testjaws_test.go b/testjaws_test.go index 545adfa..a502ac3 100644 --- a/testjaws_test.go +++ b/testjaws_test.go @@ -35,7 +35,7 @@ type testRequest struct { readyCh chan struct{} doneCh chan struct{} inCh chan wsMsg - outCh chan string + outCh chan wsMsg bcastCh chan Message ctx context.Context cancel context.CancelFunc @@ -70,7 +70,7 @@ func (tj *testJaws) newRequest(hr *http.Request) (tr *testRequest) { readyCh: make(chan struct{}), doneCh: make(chan struct{}), inCh: make(chan wsMsg), - outCh: make(chan string, cap(bcastCh)), + outCh: make(chan wsMsg, cap(bcastCh)), bcastCh: bcastCh, ctx: ctx, cancel: cancel, diff --git a/uicheckbox_test.go b/uicheckbox_test.go index 62eadb9..46edd29 100644 --- a/uicheckbox_test.go +++ b/uicheckbox_test.go @@ -42,7 +42,8 @@ func TestRequest_Checkbox(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "Value\tJid.1\t\"true\"\n" { t.Errorf("%q", s) } @@ -58,7 +59,8 @@ func TestRequest_Checkbox(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "Alert\t\t\"danger\\nstrconv.ParseBool: parsing "omg": invalid syntax\"\n" { t.Errorf("wrong Alert: %q", s) } @@ -69,7 +71,8 @@ func TestRequest_Checkbox(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "Alert\t\t\"danger\\nmeh\"\n" { t.Errorf("wrong Alert: %q", s) } diff --git a/uidate_test.go b/uidate_test.go index 5b05e9e..c861411 100644 --- a/uidate_test.go +++ b/uidate_test.go @@ -46,7 +46,8 @@ func TestRequest_Date(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != fmt.Sprintf("Value\tJid.1\t\"%s\"\n", val.Format(ISO8601)) { t.Error("wrong Value") } @@ -62,7 +63,8 @@ func TestRequest_Date(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "Alert\t\t\"danger\\nparsing time "omg" as "2006-01-02": cannot parse "omg" as "2006"\"\n" { t.Errorf("wrong Alert: %q", s) } @@ -73,7 +75,8 @@ func TestRequest_Date(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "Alert\t\t\"danger\\nmeh\"\n" { t.Errorf("wrong Alert: %q", s) } diff --git a/uiimg_test.go b/uiimg_test.go index 46da5e1..89b662d 100644 --- a/uiimg_test.go +++ b/uiimg_test.go @@ -25,7 +25,8 @@ func TestRequest_Img(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "SAttr\tJid.1\t\"src\\n\\\"unquoted.jpg\\\"\"\n" { t.Error(strconv.Quote(s)) } diff --git a/uinumber_test.go b/uinumber_test.go index 3d3b1ff..e45cd92 100644 --- a/uinumber_test.go +++ b/uinumber_test.go @@ -43,7 +43,8 @@ func TestRequest_Number(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != fmt.Sprintf("Value\tJid.1\t\"%v\"\n", val) { t.Error("wrong Value") } @@ -59,7 +60,8 @@ func TestRequest_Number(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "Alert\t\t\"danger\\nstrconv.ParseFloat: parsing "omg": invalid syntax\"\n" { t.Errorf("wrong Alert: %q", s) } @@ -70,7 +72,8 @@ func TestRequest_Number(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "Alert\t\t\"danger\\nmeh\"\n" { t.Errorf("wrong Alert: %q", s) } diff --git a/uioption_test.go b/uioption_test.go index 4f68351..8fb88a2 100644 --- a/uioption_test.go +++ b/uioption_test.go @@ -30,7 +30,8 @@ func TestUiOption(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "RAttr\tJid.1\t\"selected\"\n" { t.Errorf("%q", s) } @@ -41,7 +42,8 @@ func TestUiOption(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "SAttr\tJid.1\t\"selected\\n\"\n" { t.Errorf("%q", s) } diff --git a/uirange_test.go b/uirange_test.go index d62b54a..59e9e9a 100644 --- a/uirange_test.go +++ b/uirange_test.go @@ -38,7 +38,8 @@ func TestRequest_Range(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "Value\tJid.1\t\"2.3\"\n" { t.Error(s) } @@ -55,7 +56,8 @@ func TestRequest_Range(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "Alert\t\t\"danger\\nmeh\"\n" { t.Errorf("wrong Alert: %q", s) } diff --git a/uiselect_test.go b/uiselect_test.go index b3caf9d..f5b18b0 100644 --- a/uiselect_test.go +++ b/uiselect_test.go @@ -82,7 +82,8 @@ func TestRequest_Select(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "Value\tJid.1\t\"\"\n" { t.Error("wrong Value") } @@ -100,7 +101,8 @@ func TestRequest_Select(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "Alert\t\t\"danger\\nmeh\"\n" { t.Errorf("wrong Alert: %q", s) } diff --git a/uitext_test.go b/uitext_test.go index ce886e7..0170ce6 100644 --- a/uitext_test.go +++ b/uitext_test.go @@ -38,7 +38,8 @@ func TestRequest_Text(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "Value\tJid.1\t\"quux\"\n" { t.Error("wrong Value") } @@ -55,7 +56,8 @@ func TestRequest_Text(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "Alert\t\t\"danger\\nmeh\"\n" { t.Errorf("wrong Alert: %q", s) } diff --git a/uitextarea_test.go b/uitextarea_test.go index a76a9ea..0c78540 100644 --- a/uitextarea_test.go +++ b/uitextarea_test.go @@ -37,7 +37,8 @@ func TestRequest_Textarea(t *testing.T) { select { case <-th.C: th.Timeout() - case s := <-rq.outCh: + case msg := <-rq.outCh: + s := msg.Format() if s != "Value\tJid.1\t\"quux\"\n" { t.Fail() } diff --git a/ws.go b/ws.go index 7fe0fb9..d708514 100644 --- a/ws.go +++ b/ws.go @@ -2,6 +2,7 @@ package jaws import ( "context" + "io" "net/http" "strings" @@ -38,10 +39,10 @@ func (rq *Request) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err = rq.onConnect(); err == nil { incomingMsgCh := make(chan wsMsg) broadcastMsgCh := rq.Jaws.subscribe(rq, 4+len(rq.elems)*4) - outboundCh := make(chan string, cap(broadcastMsgCh)) + outboundMsgCh := make(chan wsMsg, cap(broadcastMsgCh)) go wsReader(rq.ctx, rq.cancelFn, rq.Jaws.Done(), incomingMsgCh, ws) // closes incomingMsgCh - go wsWriter(rq.ctx, rq.cancelFn, rq.Jaws.Done(), outboundCh, ws) // calls ws.Close() - rq.process(broadcastMsgCh, incomingMsgCh, outboundCh) // unsubscribes broadcastMsgCh, closes outboundMsgCh + go wsWriter(rq.ctx, rq.cancelFn, rq.Jaws.Done(), outboundMsgCh, ws) // calls ws.Close() + rq.process(broadcastMsgCh, incomingMsgCh, outboundMsgCh) // unsubscribes broadcastMsgCh, closes outboundMsgCh } else { defer ws.Close(websocket.StatusNormalClosure, err.Error()) var msg wsMsg @@ -82,7 +83,7 @@ func wsReader(ctx context.Context, ccf context.CancelCauseFunc, jawsDoneCh <-cha // wsWriter reads JaWS messages from outboundMsgCh, formats them and writes them to the websocket. // // Closes the websocket on exit. -func wsWriter(ctx context.Context, ccf context.CancelCauseFunc, jawsDoneCh <-chan struct{}, outboundCh <-chan string, ws *websocket.Conn) { +func wsWriter(ctx context.Context, ccf context.CancelCauseFunc, jawsDoneCh <-chan struct{}, outboundMsgCh <-chan wsMsg, ws *websocket.Conn) { defer ws.Close(websocket.StatusNormalClosure, "") var err error for err == nil { @@ -91,14 +92,34 @@ func wsWriter(ctx context.Context, ccf context.CancelCauseFunc, jawsDoneCh <-cha return case <-jawsDoneCh: return - case msg, ok := <-outboundCh: + case msg, ok := <-outboundMsgCh: if !ok { return } - err = ws.Write(ctx, websocket.MessageText, []byte(msg)) + var wc io.WriteCloser + if wc, err = ws.Writer(ctx, websocket.MessageText); err == nil { + err = wsWriteData(wc, msg, outboundMsgCh) + } } } if ccf != nil { ccf(err) } } + +func wsWriteData(wc io.WriteCloser, firstMsg wsMsg, outboundMsgCh <-chan wsMsg) (err error) { + defer wc.Close() + b := firstMsg.Append(nil) + for { + select { + case msg := <-outboundMsgCh: + b = msg.Append(b) + if len(b) < 2*1024 { + continue + } + default: + } + _, err = wc.Write(b) + return + } +} diff --git a/ws_test.go b/ws_test.go index ae2c4ed..c9b3438 100644 --- a/ws_test.go +++ b/ws_test.go @@ -250,7 +250,7 @@ func TestWriter_SendsThePayload(t *testing.T) { ts := newTestServer() defer ts.Close() - outCh := make(chan string) + outCh := make(chan wsMsg) defer close(outCh) client, server := Pipe() @@ -270,7 +270,7 @@ func TestWriter_SendsThePayload(t *testing.T) { select { case <-th.C: th.Timeout() - case outCh <- msg.Format(): + case outCh <- msg: } select { @@ -296,13 +296,70 @@ func TestWriter_SendsThePayload(t *testing.T) { } } +func TestWriter_ConcatenatesMessages(t *testing.T) { + th := newTestHelper(t) + ts := newTestServer() + defer ts.Close() + + outCh := make(chan wsMsg, 2) + defer close(outCh) + client, server := Pipe() + + go wsWriter(ts.ctx, nil, ts.jw.Done(), outCh, server) + + var mt websocket.MessageType + var b []byte + var err error + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + mt, b, err = client.Read(ts.ctx) + ts.cancel() + }() + + msg := wsMsg{Jid: Jid(1234)} + select { + case <-th.C: + th.Timeout() + case outCh <- msg: + } + select { + case <-th.C: + th.Timeout() + case outCh <- msg: + } + + select { + case <-th.C: + th.Timeout() + case <-doneCh: + } + + if err != nil { + t.Error(err) + } + if mt != websocket.MessageText { + t.Error(mt) + } + want := msg.Format() + msg.Format() + if string(b) != want { + t.Error(string(b)) + } + + select { + case <-th.C: + th.Timeout() + case <-client.CloseRead(ts.ctx).Done(): + } +} + func TestWriter_RespectsContext(t *testing.T) { th := newTestHelper(t) ts := newTestServer() defer ts.Close() doneCh := make(chan struct{}) - outCh := make(chan string) + outCh := make(chan wsMsg) defer close(outCh) client, server := Pipe() client.CloseRead(context.Background()) @@ -328,7 +385,7 @@ func TestWriter_RespectsJawsDone(t *testing.T) { defer ts.Close() doneCh := make(chan struct{}) - outCh := make(chan string) + outCh := make(chan wsMsg) defer close(outCh) client, server := Pipe() client.CloseRead(ts.ctx) @@ -353,7 +410,7 @@ func TestWriter_RespectsOutboundClosed(t *testing.T) { defer ts.Close() doneCh := make(chan struct{}) - outCh := make(chan string) + outCh := make(chan wsMsg) client, server := Pipe() client.CloseRead(ts.ctx) @@ -381,7 +438,7 @@ func TestWriter_ReportsError(t *testing.T) { defer ts.Close() doneCh := make(chan struct{}) - outCh := make(chan string) + outCh := make(chan wsMsg) client, server := Pipe() client.CloseRead(ts.ctx) server.Close(websocket.StatusNormalClosure, "") @@ -395,7 +452,7 @@ func TestWriter_ReportsError(t *testing.T) { select { case <-th.C: th.Timeout() - case outCh <- msg.Format(): + case outCh <- msg: } select {