Skip to content

Commit

Permalink
Integrated the websocket client into the REST client
Browse files Browse the repository at this point in the history
  • Loading branch information
tanis2000 committed Jun 6, 2024
1 parent 82701bd commit 001ed59
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 32 deletions.
6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 19 additions & 11 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,34 @@ type Client struct {
queuedItems map[string]*QueuePromptResponse
queuedCount int
callbacks *Callbacks
wsc *WebSocketClient
}

type Callbacks struct {
OnStatus func(*Client, int)
OnExecutionStart func(*Client, *QueuePromptResponse)
OnExecuted func(*Client, *QueuePromptResponse)
OnExecuting func(*Client, *QueuePromptResponse, string)
OnProgress func(*Client, *WSStatusMessageDataProgress)
}

func NewClient(serverAddress string, serverPort int, callbacks *Callbacks) *Client {
return &Client{
func NewClient(serverAddress string, serverPort int, callbacks *Callbacks) (*Client, error) {
res := &Client{
serverAddress: serverAddress,
serverPort: serverPort,
clientId: uuid.New().String(),
queuedItems: make(map[string]*QueuePromptResponse),
callbacks: callbacks,
}
res.wsc = NewWebSocketClient(res)
err := res.wsc.Connect(serverAddress, serverPort, res.clientId)
if err != nil {
return nil, err
}
go func() {
res.wsc.HandleMessages()
}()
return res, nil
}

func (c *Client) buildUrl(path string) string {
Expand All @@ -46,6 +57,7 @@ func (c *Client) OnMessage(message string) {
}

func (c *Client) OnWebSocketMessage(msg string) {
slog.Info("msg:", "msg", msg)
message := &WSStatusMessage{}
err := json.Unmarshal([]byte(msg), message)
if err != nil {
Expand Down Expand Up @@ -78,17 +90,13 @@ func (c *Client) OnWebSocketMessage(msg string) {
if c.callbacks != nil && c.callbacks.OnExecuting != nil {
c.callbacks.OnExecuting(c, qi, s.Node)
}
qi.Messages <- "executing" + s.Node
qi.Messages <- "executing " + s.Node
}
case "progress":
//s := message.Data.(*WSStatusMessageDataProgress)
//qi := c.GetQueuedItem(s.PromptID)
//if qi != nil {
// if c.callbacks != nil && c.callbacks.OnExecuting != nil {
// c.callbacks.OnExecuting(c, qi, s.Node)
// }
// qi.Messages <- "progress"
//}
s := message.Data.(*WSStatusMessageDataProgress)
if c.callbacks != nil && c.callbacks.OnProgress != nil {
c.callbacks.OnProgress(c, s)
}
case "executed":
s := message.Data.(*WSStatusMessageDataExecuted)
qi := c.GetQueuedItem(s.PromptID)
Expand Down
26 changes: 5 additions & 21 deletions examples/txt2img/txt2img.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ func main() {
log.Printf("Queue size: %d", queuedItems)
},
}
c := client.NewClient("localhost", 8188, callbacks)
c, err := client.NewClient("localhost", 8188, callbacks)
if err != nil {
panic(err)
}

println("Getting System Stats")
stats, err := c.GetSystemStats()
Expand Down Expand Up @@ -131,29 +134,10 @@ func main() {
}
println(res)

println("Starting the websocket")
wsc := client.NewWebSocketClient(c)
err = wsc.Connect("localhost", 8188, c.ClientId())
if err != nil {
panic(err)
}
println("Pinging the websocket")
err = wsc.Ping()
if err != nil {
panic(err)
}
go func() {
println("Handling messages")
wsc.HandleMessages()
}()

println("Starting the loop")
for continueLoop := true; continueLoop; {
msg := <-res.Messages
println(msg)
}
err = wsc.Close()
if err != nil {
panic(err)
}

}

0 comments on commit 001ed59

Please sign in to comment.