Refactored and added errgroup to handle goroutines errors

This commit is contained in:
Maciej Krzyżanowski 2024-03-29 16:53:10 +01:00
parent f16b53278b
commit 9c0e3e607f
5 changed files with 128 additions and 46 deletions

View File

@ -1,6 +1,7 @@
package client package client
import ( import (
"golang.org/x/sync/errgroup"
"math/rand" "math/rand"
"net/url" "net/url"
"os" "os"
@ -76,7 +77,7 @@ func (cliCtx *Context) serverReader() error {
logger.Debug("frame read", "id", rFrame.ID) logger.Debug("frame read", "id", rFrame.ID)
if rFrame.ID > 128 { if rFrame.IsResponse() {
cliCtx.resFromServer <- rFrame cliCtx.resFromServer <- rFrame
} else { } else {
cliCtx.reqFromServer <- rFrame cliCtx.reqFromServer <- rFrame
@ -116,38 +117,65 @@ func init() {
func testAuth(ctx *Context) { func testAuth(ctx *Context) {
logger.Info("Trying to authenticate as krzmaciek...") logger.Info("Trying to authenticate as krzmaciek...")
ctx.sendRequest(cm.AuthRequest{Nickname: "krzmaciek", Password: "9maciek1"}) err := ctx.sendRequest(cm.AuthRequest{Nickname: "krzmaciek", Password: "9maciek1"})
if err != nil {
logger.Error(err)
return
}
logger.Debug("Request sent, waiting for response...") logger.Debug("Request sent, waiting for response...")
arf := ctx.getResponseFrame() arf := ctx.getResponseFrame()
ar, err := cm.ResponseFromFrame[cm.AuthResponse](arf) ar, err := cm.ResponseFromFrame[cm.AuthResponse](arf)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
return
} }
logger.Infof("Authenticated?: %t", ar.IsSuccess) logger.Infof("Authenticated?: %t", ar.IsSuccess)
} }
func testEcho(ctx *Context) { func testEcho(ctx *Context) {
echoByte := rand.Intn(32) echoByte := rand.Intn(32)
logger.Info("Testing echo...", "echoByte", echoByte) logger.Info("Testing echo...", "echoByte", echoByte)
ctx.sendRequest(cm.EchoRequest{EchoByte: byte(echoByte)}) err := ctx.sendRequest(cm.EchoRequest{EchoByte: byte(echoByte)})
if err != nil {
logger.Error(err)
return
}
logger.Debug("Request sent, waiting for response...") logger.Debug("Request sent, waiting for response...")
ereqf := ctx.getResponseFrame() ereqf := ctx.getResponseFrame()
ereq, err := cm.ResponseFromFrame[cm.EchoResponse](ereqf) ereq, err := cm.ResponseFromFrame[cm.EchoResponse](ereqf)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
return
} }
logger.Info("Got response", "echoByte", ereq.EchoByte) logger.Info("Got response", "echoByte", ereq.EchoByte)
} }
func testListPeers(ctx *Context) { func testListPeers(ctx *Context) {
logger.Info("Trying to get list of peers...") logger.Info("Trying to get list of peers...")
ctx.sendRequest(cm.ListPeersRequest{}) err := ctx.sendRequest(cm.ListPeersRequest{})
if err != nil {
logger.Error(err)
return
}
logger.Debug("Request sent, waiting for response...") logger.Debug("Request sent, waiting for response...")
lpreqf := ctx.getResponseFrame() lpreqf := ctx.getResponseFrame()
lpreq, err := cm.ResponseFromFrame[cm.ListPeersResponse](lpreqf) lpreq, err := cm.ResponseFromFrame[cm.ListPeersResponse](lpreqf)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
return
} }
logger.Info("Got that list", "peersList", lpreq.PeersInfo) logger.Info("Got that list", "peersList", lpreq.PeersInfo)
} }
@ -160,18 +188,32 @@ func RunClient() {
return return
} }
defer c.Close() defer func(c *websocket.Conn) {
err := c.Close()
if err != nil {
logger.Error(err)
}
}(c)
ctx := NewClientContext(c) ctx := NewClientContext(c)
go ctx.serverHandler() errGroup := new(errgroup.Group)
go ctx.serverReader() errGroup.Go(ctx.serverHandler)
go ctx.serverWriter() errGroup.Go(ctx.serverReader)
errGroup.Go(ctx.serverWriter)
testAuth(ctx) testAuth(ctx)
testEcho(ctx) testEcho(ctx)
testListPeers(ctx) testListPeers(ctx)
err = errGroup.Wait()
time.Sleep(time.Second * 5) if err != nil {
logger.Info("closing connection...") logger.Error(err)
c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) }
logger.Info("closing connection...")
err = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
logger.Error(err)
}
} }

View File

@ -20,7 +20,7 @@ const (
type PeerInfo struct { type PeerInfo struct {
ID int `json:"id"` ID int `json:"id"`
Addr string `json:"addr"` Addr string `json:"addr"`
HasNickaname bool `json:"hasNickname"` HasNickname bool `json:"hasNickname"`
Nickname string `json:"nickname"` Nickname string `json:"nickname"`
} }
@ -31,6 +31,14 @@ type RFrame struct {
Rest json.RawMessage `json:"r"` Rest json.RawMessage `json:"r"`
} }
func (rf RFrame) IsRequest() bool {
return rf.ID <= 128
}
func (rf RFrame) IsResponse() bool {
return rf.ID > 128
}
func RequestFrameFrom(req Request) (RFrame, error) { func RequestFrameFrom(req Request) (RFrame, error) {
jsonBytes, err := json.Marshal(req) jsonBytes, err := json.Marshal(req)

1
go.mod
View File

@ -20,5 +20,6 @@ require (
github.com/rivo/uniseg v0.4.7 // indirect github.com/rivo/uniseg v0.4.7 // indirect
golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect
golang.org/x/net v0.21.0 // indirect golang.org/x/net v0.21.0 // indirect
golang.org/x/sync v0.6.0 // indirect
golang.org/x/sys v0.18.0 // indirect golang.org/x/sys v0.18.0 // indirect
) )

2
go.sum
View File

@ -29,6 +29,8 @@ golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=

View File

@ -2,6 +2,7 @@ package server
import ( import (
"encoding/json" "encoding/json"
"golang.org/x/sync/errgroup"
"net/http" "net/http"
"os" "os"
"strings" "strings"
@ -46,9 +47,7 @@ func NewHandlerContext(peer *Peer, srvCtx *Context) *HandlerContext {
} }
} }
func (hdlCtx *HandlerContext) clientHandler(hdlWg *sync.WaitGroup) error { func (hdlCtx *HandlerContext) clientHandler() error {
defer hdlWg.Done()
for { for {
reqFrame := <-hdlCtx.reqFromClient reqFrame := <-hdlCtx.reqFromClient
var res common.Response var res common.Response
@ -68,6 +67,7 @@ func (hdlCtx *HandlerContext) clientHandler(hdlWg *sync.WaitGroup) error {
} }
resFrame, err := common.ResponseFrameFrom(res) resFrame, err := common.ResponseFrameFrom(res)
if err != nil { if err != nil {
logger.Errorf("could not create frame from response") logger.Errorf("could not create frame from response")
return err return err
@ -77,12 +77,11 @@ func (hdlCtx *HandlerContext) clientHandler(hdlWg *sync.WaitGroup) error {
} }
} }
func (hdlCtx *HandlerContext) clientWriter(hdlWg *sync.WaitGroup) error { func (hdlCtx *HandlerContext) clientWriter() error {
defer hdlWg.Done()
for { for {
rFrame := <-hdlCtx.rToClient rFrame := <-hdlCtx.rToClient
resJsonBytes, err := json.Marshal(rFrame) resJsonBytes, err := json.Marshal(rFrame)
if err != nil { if err != nil {
logger.Errorf("error marshalling frame to json") logger.Errorf("error marshalling frame to json")
return err return err
@ -90,6 +89,7 @@ func (hdlCtx *HandlerContext) clientWriter(hdlWg *sync.WaitGroup) error {
logger.Debugf("sending %s", string(resJsonBytes)) logger.Debugf("sending %s", string(resJsonBytes))
err = hdlCtx.peer.conn.WriteMessage(websocket.TextMessage, resJsonBytes) err = hdlCtx.peer.conn.WriteMessage(websocket.TextMessage, resJsonBytes)
if err != nil { if err != nil {
logger.Errorf("error writing rframe") logger.Errorf("error writing rframe")
return err return err
@ -97,11 +97,10 @@ func (hdlCtx *HandlerContext) clientWriter(hdlWg *sync.WaitGroup) error {
} }
} }
func (hdlCtx *HandlerContext) clientReader(hdlWg *sync.WaitGroup) error { func (hdlCtx *HandlerContext) clientReader() error {
defer hdlWg.Done()
for { for {
messType, messBytes, err := hdlCtx.peer.conn.ReadMessage() messType, messBytes, err := hdlCtx.peer.conn.ReadMessage()
if err != nil { if err != nil {
return err return err
} }
@ -116,10 +115,15 @@ func (hdlCtx *HandlerContext) clientReader(hdlWg *sync.WaitGroup) error {
logger.Debugf("got message text: %s", strings.Trim(string(messBytes), "\n")) logger.Debugf("got message text: %s", strings.Trim(string(messBytes), "\n"))
var rFrame common.RFrame var rFrame common.RFrame
json.Unmarshal(messBytes, &rFrame) err = json.Unmarshal(messBytes, &rFrame)
if err != nil {
return err
}
logger.Debugf("unmarshalled request frame (ID=%d)", rFrame.ID) logger.Debugf("unmarshalled request frame (ID=%d)", rFrame.ID)
if rFrame.ID > 128 { if rFrame.IsResponse() {
logger.Debug("it is response frame", "id", rFrame.ID) logger.Debug("it is response frame", "id", rFrame.ID)
hdlCtx.resFromClient <- rFrame hdlCtx.resFromClient <- rFrame
} else { } else {
@ -131,6 +135,7 @@ func (hdlCtx *HandlerContext) clientReader(hdlWg *sync.WaitGroup) error {
func (hdlCtx *HandlerContext) sendRequest(req common.Request) error { func (hdlCtx *HandlerContext) sendRequest(req common.Request) error {
rf, err := common.RequestFrameFrom(req) rf, err := common.RequestFrameFrom(req)
if err != nil { if err != nil {
return err return err
} }
@ -171,11 +176,13 @@ func NewPeer(conn *websocket.Conn) *Peer {
func peerSliceIndexOf(s []*Peer, id int) int { func peerSliceIndexOf(s []*Peer, id int) int {
i := 0 i := 0
var p *Peer var p *Peer
for i, p = range s { for i, p = range s {
if p.id == id { if p.id == id {
break break
} }
} }
return i return i
} }
@ -197,6 +204,7 @@ func handleDisconnection(handlerCtx *HandlerContext) {
func (hdlCtx *HandlerContext) handleEcho(reqFrame *common.RFrame) (res common.Response, err error) { func (hdlCtx *HandlerContext) handleEcho(reqFrame *common.RFrame) (res common.Response, err error) {
echoReq, err := common.RequestFromFrame[common.EchoRequest](*reqFrame) echoReq, err := common.RequestFromFrame[common.EchoRequest](*reqFrame)
if err != nil { if err != nil {
logger.Error("could not read request from frame") logger.Error("could not read request from frame")
return nil, err return nil, err
@ -209,6 +217,7 @@ func (hdlCtx *HandlerContext) handleEcho(reqFrame *common.RFrame) (res common.Re
func (hdlCtx *HandlerContext) handleListPeers(reqFrame *common.RFrame) (res common.Response, err error) { func (hdlCtx *HandlerContext) handleListPeers(reqFrame *common.RFrame) (res common.Response, err error) {
// Currently list peers request is empty, so we can ignore it - we won't use it // Currently list peers request is empty, so we can ignore it - we won't use it
_, err = common.RequestFromFrame[common.ListPeersRequest](*reqFrame) _, err = common.RequestFromFrame[common.ListPeersRequest](*reqFrame)
if err != nil { if err != nil {
logger.Error("could not read request from frame") logger.Error("could not read request from frame")
return nil, err return nil, err
@ -226,7 +235,7 @@ func (hdlCtx *HandlerContext) handleListPeers(reqFrame *common.RFrame) (res comm
common.PeerInfo{ common.PeerInfo{
ID: peer.id, ID: peer.id,
Addr: peer.conn.RemoteAddr().String(), Addr: peer.conn.RemoteAddr().String(),
HasNickaname: peer.hasAccount, HasNickname: peer.hasAccount,
Nickname: peer.account.nickname, Nickname: peer.account.nickname,
}, },
) )
@ -237,6 +246,7 @@ func (hdlCtx *HandlerContext) handleListPeers(reqFrame *common.RFrame) (res comm
func (hdlCtx *HandlerContext) handleAuth(reqFrame *common.RFrame) (res common.Response, err error) { func (hdlCtx *HandlerContext) handleAuth(reqFrame *common.RFrame) (res common.Response, err error) {
authReq, err := common.RequestFromFrame[common.AuthRequest](*reqFrame) authReq, err := common.RequestFromFrame[common.AuthRequest](*reqFrame)
if err != nil { if err != nil {
logger.Error("could not read request from frame") logger.Error("could not read request from frame")
return nil, err return nil, err
@ -307,9 +317,26 @@ func (srvCtx *Context) addPeer(peer *Peer) {
srvCtx.peersListLock.Unlock() srvCtx.peersListLock.Unlock()
} }
func testEcho(hdlCtx *HandlerContext) {
logger.Debug("sending echo request...")
_ = hdlCtx.sendRequest(common.EchoRequest{EchoByte: 123})
logger.Debug("sent")
echoResF := hdlCtx.getResponseFrame()
logger.Debug("got response")
echoRes, err := common.ResponseFromFrame[common.EchoResponse](echoResF)
if err != nil {
logger.Error(err)
return
}
logger.Debug("test echo done", "byteSent", 123, "byteReceived", echoRes.EchoByte)
}
func (srvCtx *Context) wsapiHandler(w http.ResponseWriter, r *http.Request) { func (srvCtx *Context) wsapiHandler(w http.ResponseWriter, r *http.Request) {
upgrader := websocket.Upgrader{} upgrader := websocket.Upgrader{}
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
logger.Errorf("upgrade failed") logger.Errorf("upgrade failed")
return return
@ -319,28 +346,26 @@ func (srvCtx *Context) wsapiHandler(w http.ResponseWriter, r *http.Request) {
srvCtx.addPeer(peer) srvCtx.addPeer(peer)
handlerCtx := NewHandlerContext(peer, srvCtx) handlerCtx := NewHandlerContext(peer, srvCtx)
defer handleDisconnection(handlerCtx) defer handleDisconnection(handlerCtx)
defer conn.Close()
defer func(conn *websocket.Conn) {
err := conn.Close()
if err != nil {
logger.Error(err)
}
}(conn)
logger.Infof("%s connected", conn.RemoteAddr()) logger.Infof("%s connected", conn.RemoteAddr())
errGroup := new(errgroup.Group)
errGroup.Go(handlerCtx.clientHandler)
errGroup.Go(handlerCtx.clientWriter)
errGroup.Go(handlerCtx.clientReader)
testEcho(handlerCtx)
err = errGroup.Wait()
var handlerWg sync.WaitGroup
handlerWg.Add(3)
go handlerCtx.clientWriter(&handlerWg)
go handlerCtx.clientHandler(&handlerWg)
go handlerCtx.clientReader(&handlerWg)
logger.Debug("sending echo request...")
handlerCtx.sendRequest(common.EchoRequest{EchoByte: 123})
logger.Debug("sent")
echoResF := handlerCtx.getResponseFrame()
logger.Debug("got response")
echoRes, err := common.ResponseFromFrame[common.EchoResponse](echoResF)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
return return
} }
logger.Debug("test echo done", "byteSent", 123, "byteReceived", echoRes.EchoByte)
handlerWg.Wait()
} }
func RunServer() { func RunServer() {
@ -355,5 +380,9 @@ func RunServer() {
http.HandleFunc("/wsapi", srvCtx.wsapiHandler) http.HandleFunc("/wsapi", srvCtx.wsapiHandler)
logger.Info("Starting server...") logger.Info("Starting server...")
http.ListenAndServe(":8080", nil) err := http.ListenAndServe(":8080", nil)
if err != nil {
logger.Error(err)
}
} }