diff --git a/jaws.go b/jaws.go index fac4970..775f352 100644 --- a/jaws.go +++ b/jaws.go @@ -475,12 +475,12 @@ func (jw *Jaws) Alert(lvl, msg string) { // Count returns the number of requests waiting for their WebSocket callbacks. func (jw *Jaws) Pending() (n int) { jw.mu.RLock() + defer jw.mu.RUnlock() for _, rq := range jw.requests { - if !rq.claimed { + if !rq.claimed.Load() { n++ } } - jw.mu.RUnlock() return } @@ -576,11 +576,11 @@ func (jw *Jaws) unsubscribe(msgCh chan Message) { } func (jw *Jaws) maintenance(requestTimeout time.Duration) { - deadline := time.Now().Add(-requestTimeout) jw.mu.Lock() defer jw.mu.Unlock() + now := time.Now() for _, rq := range jw.requests { - if rq.maintenance(deadline) { + if rq.maintenance(now, requestTimeout) { jw.recycleLocked(rq) } } @@ -716,7 +716,7 @@ func (jw *Jaws) Append(target any, html template.HTML) { func (jw *Jaws) getRequestLocked(jawsKey uint64, hr *http.Request) (rq *Request) { rq = jw.reqPool.Get().(*Request) rq.JawsKey = jawsKey - rq.Created = time.Now() + rq.lastWrite = time.Now() rq.initial = hr rq.ctx, rq.cancelFn = context.WithCancelCause(context.Background()) if hr != nil { diff --git a/jaws_test.go b/jaws_test.go index a7f70e8..0da1aae 100644 --- a/jaws_test.go +++ b/jaws_test.go @@ -6,6 +6,7 @@ import ( "context" "errors" "html/template" + "io" "log" "net/http" "net/http/httptest" @@ -242,6 +243,10 @@ func TestJaws_UseRequest(t *testing.T) { th.Equal(rq2, rq2ret) th.Equal(jw.Pending(), 1) + rqfail = jw.UseRequest(rq2.JawsKey, &http.Request{RemoteAddr: "10.0.0.2:1214"}) // already claimed + th.Equal(rqfail, nil) + th.Equal(jw.Pending(), 1) + rq1ret := jw.UseRequest(rq1.JawsKey, nil) th.Equal(rq1, rq1ret) th.Equal(jw.Pending(), 0) @@ -275,13 +280,10 @@ func TestJaws_CleansUpUnconnected(t *testing.T) { for i := 0; i < numReqs; i++ { rq := jw.NewRequest(hr) if (i % (numReqs / 5)) == 0 { - elem := rq.NewElement(NewUiDiv(makeHtmlGetter("meh"))) - for j := 0; j < maxWsQueueLengthPerElement*10; j++ { - elem.SetInner("foo") - } + rq.NewElement(NewUiDiv(makeHtmlGetter("meh"))) } err := context.Cause(rq.ctx) - if err == nil && rq.Created.Before(deadline) { + if err == nil && rq.lastWrite.Before(deadline) { err = newErrPendingCancelledLocked(rq, newErrNoWebSocketRequest(rq)) } if err == nil { @@ -316,6 +318,54 @@ func TestJaws_CleansUpUnconnected(t *testing.T) { } } +func TestJaws_RequestWriterExtendsDeadline(t *testing.T) { + th := newTestHelper(t) + jw := New() + defer jw.Close() + var b bytes.Buffer + w := bufio.NewWriter(&b) + jw.Logger = log.New(w, "", 0) + defer jw.Close() + + hr := httptest.NewRequest(http.MethodGet, "/", nil) + rq := jw.NewRequest(hr) + rq.lastWrite = time.Now().Add(time.Second) + lastWrite := rq.lastWrite + + var sb strings.Builder + rw := rq.Writer(&sb) + + ui := &testUi{renderFn: func(e *Element, w io.Writer, params []any) error { + w.Write(nil) + return nil + }} + + rw.UI(ui) + + th.True(ui.renderCalled > 0) + th.True(rq.rendering.Load()) + th.Equal(lastWrite, rq.getLastWrite()) + + go jw.ServeWithTimeout(time.Millisecond) + + for lastWrite == rq.getLastWrite() { + select { + case <-th.C: + th.Timeout() + case <-jw.Done(): + th.Error("unexpected close") + default: + time.Sleep(time.Millisecond) + } + } + if rq.getLastWrite().IsZero() { + th.Error("last write is zero") + } + if rq.getLastWrite() == lastWrite { + th.Error("last write not modified") + } +} + func TestJaws_UnconnectedLivesUntilDeadline(t *testing.T) { th := newTestHelper(t) jw := New() @@ -325,7 +375,7 @@ func TestJaws_UnconnectedLivesUntilDeadline(t *testing.T) { rq1 := jw.NewRequest(hr) rq1ctx := rq1.Context() rq2 := jw.NewRequest(hr) - rq2.Created = time.Now().Add(-time.Second * 10) + rq2.lastWrite = time.Now().Add(-time.Second * 10) rq2ctx := rq2.Context() th.Equal(jw.Pending(), 2) diff --git a/request.go b/request.go index 506711b..0e9df1a 100644 --- a/request.go +++ b/request.go @@ -31,13 +31,14 @@ type ConnectFn = func(rq *Request) error type Request struct { Jaws *Jaws // (read-only) the JaWS instance the Request belongs to JawsKey uint64 // (read-only) a random number used in the WebSocket URI to identify this Request - Created time.Time // (read-only) when the Request was created, used for automatic cleanup remoteIP netip.Addr // (read-only) remote IP, or nil + rendering atomic.Bool // set to true by RequestWriter.Write() + running atomic.Bool // if ServeHTTP() is running + claimed atomic.Bool // if UseRequest() has been called for it mu deadlock.RWMutex // protects following + lastWrite time.Time // when the initial HTML was last written to, used for automatic cleanup initial *http.Request // initial HTTP request passed to Jaws.NewRequest session *Session // session, if established - claimed bool // if UseRequest() has been called for it - running bool // if ServeHTTP() is running todoDirt []any // dirty tags ctx context.Context // current context, derived from either Jaws or WS HTTP req cancelFn context.CancelCauseFunc // cancel function @@ -53,8 +54,6 @@ type eventFnCall struct { data string } -const maxWsQueueLengthPerElement = 100 - func (rq *Request) JawsKeyString() string { jawsKey := uint64(0) if rq != nil { @@ -70,25 +69,25 @@ func (rq *Request) String() string { var ErrRequestAlreadyClaimed = errors.New("request already claimed") var ErrJavascriptDisabled = errors.New("javascript is disabled") -func (rq *Request) claim(hr *http.Request) (err error) { - rq.mu.Lock() - defer rq.mu.Unlock() - if rq.claimed { - return ErrRequestAlreadyClaimed - } - var actualIP netip.Addr - ctx := context.Background() - if hr != nil { - actualIP = parseIP(hr.RemoteAddr) - ctx = hr.Context() - } - if equalIP(rq.remoteIP, actualIP) { - rq.ctx, rq.cancelFn = context.WithCancelCause(ctx) - rq.claimed = true - } else { - err = fmt.Errorf("/jaws/%s: expected IP %q, got %q", rq.JawsKeyString(), rq.remoteIP.String(), actualIP.String()) +func (rq *Request) claim(hr *http.Request) error { + if !rq.claimed.Load() { + var actualIP netip.Addr + ctx := context.Background() + if hr != nil { + actualIP = parseIP(hr.RemoteAddr) + ctx = hr.Context() + } + rq.mu.Lock() + defer rq.mu.Unlock() + if !equalIP(rq.remoteIP, actualIP) { + return fmt.Errorf("/jaws/%s: expected IP %q, got %q", rq.JawsKeyString(), rq.remoteIP.String(), actualIP.String()) + } + if rq.claimed.CompareAndSwap(false, true) { + rq.ctx, rq.cancelFn = context.WithCancelCause(ctx) + return nil + } } - return + return ErrRequestAlreadyClaimed } func (rq *Request) killSessionLocked() { @@ -107,10 +106,10 @@ func (rq *Request) killSession() { func (rq *Request) clearLocked() *Request { rq.JawsKey = 0 rq.connectFn = nil - rq.Created = time.Time{} + rq.lastWrite = time.Time{} rq.initial = nil - rq.claimed = false - rq.running = false + rq.running.Store(false) + rq.claimed.Store(false) rq.ctx, rq.cancelFn = context.WithCancelCause(context.Background()) rq.todoDirt = rq.todoDirt[:0] rq.remoteIP = netip.Addr{} @@ -216,15 +215,22 @@ func (rq *Request) Context() (ctx context.Context) { return } -func (rq *Request) maintenance(deadline time.Time) bool { - rq.mu.RLock() - defer rq.mu.RUnlock() - if !rq.running { - if rq.ctx.Err() != nil { +func (rq *Request) maintenance(now time.Time, requestTimeout time.Duration) bool { + if !rq.running.Load() { + if rq.rendering.Swap(false) { + rq.mu.Lock() + rq.lastWrite = now + rq.mu.Unlock() + } + rq.mu.RLock() + err := rq.ctx.Err() + since := now.Sub(rq.lastWrite) + rq.mu.RUnlock() + if err != nil { return true } - if rq.Created.Before(deadline) { - rq.cancelLocked(newErrNoWebSocketRequest(rq)) + if since > requestTimeout { + rq.cancel(newErrNoWebSocketRequest(rq)) return true } } @@ -233,7 +239,7 @@ func (rq *Request) maintenance(deadline time.Time) bool { func (rq *Request) cancelLocked(err error) { if rq.JawsKey != 0 && rq.ctx.Err() == nil { - if !rq.running { + if !rq.running.Load() { err = newErrPendingCancelledLocked(rq, err) } rq.cancelFn(rq.Jaws.Log(err)) @@ -493,15 +499,6 @@ func (rq *Request) process(broadcastMsgCh chan Message, incomingMsgCh <-chan wsM elem.update() } - /*// Append pending WS messages to the queue - // in the order of Element creation. - rq.mu.RLock() - for _, elem := range rq.elems { - wsQueue = append(wsQueue, elem.wsQueue...) - elem.wsQueue = elem.wsQueue[:0] - } - rq.mu.RUnlock()*/ - rq.sendQueue(outboundMsgCh) select { @@ -604,9 +601,7 @@ func (rq *Request) handleRemove(data string) { } func (rq *Request) queue(msg wsMsg) { - // rq.mu.Lock() rq.wsQueue = append(rq.wsQueue, msg) - // rq.mu.Unlock() } func (rq *Request) callAllEventHandlers(id Jid, wht what.What, val string) (err error) { @@ -748,6 +743,13 @@ func (rq *Request) makeUpdateList() (todo []*Element) { return } +func (rq *Request) getLastWrite() (when time.Time) { + rq.mu.RLock() + when = rq.lastWrite + rq.mu.RUnlock() + return +} + // eventCaller calls event functions func (rq *Request) eventCaller(eventCallCh <-chan eventFnCall, outboundMsgCh chan<- wsMsg, eventDoneCh chan<- struct{}) { defer close(eventDoneCh) diff --git a/requestwriter.go b/requestwriter.go index 66c9a76..eda9520 100644 --- a/requestwriter.go +++ b/requestwriter.go @@ -11,7 +11,12 @@ type RequestWriter struct { } func (rw RequestWriter) UI(ui UI, params ...any) error { - return rw.rq.render(rw.rq.NewElement(ui), rw.Writer, params) + return rw.rq.render(rw.rq.NewElement(ui), rw, params) +} + +func (rw RequestWriter) Write(p []byte) (n int, err error) { + rw.rq.rendering.Store(true) + return rw.Writer.Write(p) } // Request returns the current jaws.Request. diff --git a/ui_test.go b/ui_test.go index 7b9c94a..dcc0d60 100644 --- a/ui_test.go +++ b/ui_test.go @@ -18,7 +18,6 @@ func (s testStringer) String() string { } func TestRequest_JawsRender_DebugOutput(t *testing.T) { - is := newTestHelper(t) rq := newTestRequest() defer rq.Close() diff --git a/ws.go b/ws.go index d708514..558a9f2 100644 --- a/ws.go +++ b/ws.go @@ -10,12 +10,7 @@ import ( ) func (rq *Request) startServe() (ok bool) { - rq.mu.Lock() - if ok = !rq.running && rq.claimed; ok { - rq.running = true - } - rq.mu.Unlock() - return + return rq.claimed.Load() && rq.running.CompareAndSwap(false, true) } func (rq *Request) stopServe() {