diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/client/client.go b/client/client.go index a0c844d..6cc58a8 100644 --- a/client/client.go +++ b/client/client.go @@ -14,6 +14,7 @@ type Client struct { queuedItems map[string]*QueuePromptResponse queuedCount int callbacks *Callbacks + wsc *WebSocketClient } type Callbacks struct { @@ -21,16 +22,26 @@ type Callbacks struct { OnExecutionStart func(*Client, *QueuePromptResponse) OnExecuted func(*Client, *QueuePromptResponse) OnExecuting func(*Client, *QueuePromptResponse, string) + OnProgress func(*Client, *WSStatusMessageDataProgress) } -func NewClient(serverAddress string, serverPort int, callbacks *Callbacks) *Client { - return &Client{ +func NewClient(serverAddress string, serverPort int, callbacks *Callbacks) (*Client, error) { + res := &Client{ serverAddress: serverAddress, serverPort: serverPort, clientId: uuid.New().String(), queuedItems: make(map[string]*QueuePromptResponse), callbacks: callbacks, } + res.wsc = NewWebSocketClient(res) + err := res.wsc.Connect(serverAddress, serverPort, res.clientId) + if err != nil { + return nil, err + } + go func() { + res.wsc.HandleMessages() + }() + return res, nil } func (c *Client) buildUrl(path string) string { @@ -46,6 +57,7 @@ func (c *Client) OnMessage(message string) { } func (c *Client) OnWebSocketMessage(msg string) { + slog.Info("msg:", "msg", msg) message := &WSStatusMessage{} err := json.Unmarshal([]byte(msg), message) if err != nil { @@ -78,17 +90,13 @@ func (c *Client) OnWebSocketMessage(msg string) { if c.callbacks != nil && c.callbacks.OnExecuting != nil { c.callbacks.OnExecuting(c, qi, s.Node) } - qi.Messages <- "executing" + s.Node + qi.Messages <- "executing " + s.Node } case "progress": - //s := message.Data.(*WSStatusMessageDataProgress) - //qi := c.GetQueuedItem(s.PromptID) - //if qi != nil { - // if c.callbacks != nil && c.callbacks.OnExecuting != nil { - // c.callbacks.OnExecuting(c, qi, s.Node) - // } - // qi.Messages <- "progress" - //} + s := message.Data.(*WSStatusMessageDataProgress) + if c.callbacks != nil && c.callbacks.OnProgress != nil { + c.callbacks.OnProgress(c, s) + } case "executed": s := message.Data.(*WSStatusMessageDataExecuted) qi := c.GetQueuedItem(s.PromptID) diff --git a/examples/txt2img/txt2img.go b/examples/txt2img/txt2img.go index 67aba85..47c7fe0 100644 --- a/examples/txt2img/txt2img.go +++ b/examples/txt2img/txt2img.go @@ -17,7 +17,10 @@ func main() { log.Printf("Queue size: %d", queuedItems) }, } - c := client.NewClient("localhost", 8188, callbacks) + c, err := client.NewClient("localhost", 8188, callbacks) + if err != nil { + panic(err) + } println("Getting System Stats") stats, err := c.GetSystemStats() @@ -131,29 +134,10 @@ func main() { } println(res) - println("Starting the websocket") - wsc := client.NewWebSocketClient(c) - err = wsc.Connect("localhost", 8188, c.ClientId()) - if err != nil { - panic(err) - } - println("Pinging the websocket") - err = wsc.Ping() - if err != nil { - panic(err) - } - go func() { - println("Handling messages") - wsc.HandleMessages() - }() - println("Starting the loop") for continueLoop := true; continueLoop; { msg := <-res.Messages println(msg) } - err = wsc.Close() - if err != nil { - panic(err) - } + }