From 361b84aac4a41fbab90511e5efb05147691b6c4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20Krzy=C5=BCanowski?= Date: Sun, 24 Mar 2024 14:09:36 +0100 Subject: [PATCH] Working websocket version --- client/client.go | 141 ++++++++++++----------- common/common.go | 139 ++++++++++++++++------ go.mod | 7 +- go.sum | 4 + server/server.go | 292 ++++++++++++++++++++++++----------------------- 5 files changed, 340 insertions(+), 243 deletions(-) diff --git a/client/client.go b/client/client.go index 295db5a..d0daac1 100644 --- a/client/client.go +++ b/client/client.go @@ -1,91 +1,100 @@ package client import ( - "bufio" - "encoding/json" "log" - "net" + "net/url" "time" + "github.com/gorilla/websocket" cm "krzyzanowski.dev/p2pchat/common" ) -type ClientContext struct { - reader *bufio.Reader - writer *bufio.Writer -} - -func perform[T cm.Request, U cm.Response](cliCtx *ClientContext, request T) (U, error) { - reqJsonBytes, err := json.Marshal(request) - - if err != nil { - return *new(U), err - } - - reqBytes := make([]byte, 0) - reqBytes = append(reqBytes, request.GetRID()) - reqBytes = append(reqBytes, reqJsonBytes...) - reqBytes = append(reqBytes, '\n') - - _, err = cliCtx.writer.Write(reqBytes) - - if err != nil { - return *new(U), err - } - - err = cliCtx.writer.Flush() - - if err != nil { - return *new(U), err - } - - resBytes, err := cliCtx.reader.ReadBytes('\n') - - if err != nil { - return *new(U), err - } - - var res U - json.Unmarshal(resBytes, &res) - return res, nil -} - func RunClient() { - conn, err := net.Dial("tcp", ":8080") + u := url.URL{Scheme: "ws", Host: ":8080", Path: "/wsapi"} + c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) if err != nil { - log.Println("[Client] err connecting") + log.Println("[Client] could not connect to websocket") return } - defer func() { - _ = conn.Close() - }() + defer c.Close() - br := bufio.NewReader(conn) - bw := bufio.NewWriter(conn) - cliCtx := &ClientContext{br, bw} + log.Println("[Client] authenticating...") + rf, _ := cm.RequestFrameFrom(cm.AuthRequest{Nickname: "krzmaciek", Password: "9maciek1"}) + err = c.WriteJSON(rf) + if err != nil { + log.Fatalln(err) + } - log.Println("[Client] connected to server") + var authResFrame cm.ResponseFrame + err = c.ReadJSON(&authResFrame) + if err != nil { + log.Fatalln(err) + } + + authRes, err := cm.ResponseFromFrame[cm.AuthResponse](authResFrame) + if err != nil { + log.Fatalln(err) + } + + log.Printf("[Client] authentication result: %t\n", authRes.IsSuccess) time.Sleep(time.Second * 1) - echoRes, err := perform[cm.EchoRequest, cm.EchoResponse](cliCtx, cm.EchoRequest{EchoByte: 5}) - + log.Println("[Client] sending echo...") + echoByte := 123 + rf, err = cm.RequestFrameFrom(cm.EchoRequest{EchoByte: byte(echoByte)}) if err != nil { - log.Fatalln("[Client] error performing echo") + log.Fatalln(err) } - log.Printf("[Client] echo sent (5), got %d\n", echoRes) - - authRes, _ := perform[cm.AuthRequest, cm.AuthResponse](cliCtx, cm.AuthRequest{Nickname: "maciek", Password: "9maciek1"}) - log.Printf("[Client] authenticated: %t\n", authRes.IsSuccess) - - listRes, _ := perform[cm.ListPeersRequest, cm.ListPeersResponse](cliCtx, cm.ListPeersRequest{}) - log.Println("[Client] printing all peers:") - - for _, peer := range listRes.PeersInfo { - log.Printf("[Client] Peer#%d from %s, hasNick: %t, nick: %s", peer.ID, peer.Addr, peer.HasNickaname, peer.Nickname) + err = c.WriteJSON(rf) + if err != nil { + log.Fatalln(err) } - time.Sleep(time.Second * 10) + var echoResFrame cm.ResponseFrame + err = c.ReadJSON(&echoResFrame) + if err != nil { + log.Fatalln(err) + } + + echoRes, err := cm.ResponseFromFrame[cm.EchoResponse](echoResFrame) + if err != nil { + log.Fatalln(err) + } + + log.Printf("[Client] sent echo of %d, got %d in return\n", echoByte, echoRes.EchoByte) + time.Sleep(time.Second) + + log.Println("[Client] i want list of peers...") + rf, err = cm.RequestFrameFrom(cm.ListPeersRequest{}) + if err != nil { + log.Fatalln(err) + } + + err = c.WriteJSON(rf) + if err != nil { + log.Fatalln(err) + } + + var listPeersResFrame cm.ResponseFrame + err = c.ReadJSON(&listPeersResFrame) + if err != nil { + log.Fatalln(err) + } + + listPeersRes, err := cm.ResponseFromFrame[cm.ListPeersResponse](listPeersResFrame) + if err != nil { + log.Fatalln(err) + } + + log.Println("[Client] printing list of peers:") + + for _, p := range listPeersRes.PeersInfo { + log.Printf("[Client] %+v\n", p) + } + + time.Sleep(time.Second * 5) + c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) } diff --git a/common/common.go b/common/common.go index 335c67a..74c690e 100644 --- a/common/common.go +++ b/common/common.go @@ -1,10 +1,18 @@ package common -type Request interface { - GetRID() byte -} +import ( + "encoding/json" +) -type Response Request +// Constants: + +const ( + EchoRID = 1 + ListPeersRID = 2 + AuthRID = 3 +) + +// Requests & responses subtypes type PeerInfo struct { ID int `json:"id"` @@ -13,56 +21,121 @@ type PeerInfo struct { Nickname string `json:"nickname"` } +// Requests & responses: + +type RequestFrame struct { + ID int `json:"id"` + Rest json.RawMessage `json:"request"` +} + +func RequestFrameFrom(req Request) (RequestFrame, error) { + jsonBytes, err := json.Marshal(req) + + if err != nil { + return *new(RequestFrame), err + } + + return RequestFrame{req.GetRID(), jsonBytes}, nil +} + +func RequestFromFrame[T Request](reqFrame RequestFrame) (T, error) { + var req T + err := json.Unmarshal(reqFrame.Rest, &req) + + if err != nil { + return *new(T), err + } + + return req, nil +} + +type ResponseFrame struct { + ID int `json:"id"` + Rest json.RawMessage `json:"response"` +} + +func ResponseFrameFrom(res Response) (ResponseFrame, error) { + jsonBytes, err := json.Marshal(res) + + if err != nil { + return *new(ResponseFrame), err + } + + return ResponseFrame{res.GetRID(), jsonBytes}, nil +} + +func ResponseFromFrame[T Response](resFrame ResponseFrame) (T, error) { + var res T + err := json.Unmarshal(resFrame.Rest, &res) + + if err != nil { + return *new(T), err + } + + return res, nil +} + +type Request interface { + GetRID() int +} + +type Response Request + type EchoRequest struct { EchoByte byte `json:"echoByte"` } +func (EchoRequest) GetRID() int { + return EchoRID +} + type EchoResponse struct { EchoByte byte `json:"echoByte"` } +func (EchoResponse) GetRID() int { + return EchoRID +} + type ListPeersRequest struct { } +func (ListPeersRequest) GetRID() int { + return ListPeersRID +} + type ListPeersResponse struct { PeersInfo []PeerInfo `json:"peers"` } +func (ListPeersResponse) GetRID() int { + return ListPeersRID +} + type AuthRequest struct { - Nickname string - Password string + Nickname string `json:"nickname"` + Password string `json:"password"` +} + +func (req AuthRequest) MarshalJSON() ([]byte, error) { + type Alias AuthRequest + return json.Marshal(&struct { + ID int `json:"id"` + Alias + }{ + AuthRID, + Alias(req), + }) +} + +func (AuthRequest) GetRID() int { + return AuthRID } type AuthResponse struct { IsSuccess bool } -const ( - EchoRID = 1 - ListPeersRID = 2 - AuthRID = 3 -) - -func (EchoRequest) GetRID() byte { - return EchoRID -} - -func (EchoResponse) GetRID() byte { - return EchoRID -} - -func (AuthRequest) GetRID() byte { +func (AuthResponse) GetRID() int { return AuthRID } - -func (AuthResponse) GetRID() byte { - return AuthRID -} - -func (ListPeersRequest) GetRID() byte { - return ListPeersRID -} - -func (ListPeersResponse) GetRID() byte { - return ListPeersRID -} diff --git a/go.mod b/go.mod index 9b1f977..5d4b78e 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,9 @@ module krzyzanowski.dev/p2pchat go 1.21.7 -require golang.org/x/crypto v0.21.0 +require ( + github.com/gorilla/websocket v1.5.1 + golang.org/x/crypto v0.21.0 +) + +require golang.org/x/net v0.21.0 // indirect diff --git a/go.sum b/go.sum index e02e133..fe24a7c 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,6 @@ +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= diff --git a/server/server.go b/server/server.go index 4135163..26a10fc 100644 --- a/server/server.go +++ b/server/server.go @@ -1,14 +1,14 @@ package server import ( - "bufio" "encoding/json" - "io" "log" - "net" + "net/http" + "strings" "sync" "time" + "github.com/gorilla/websocket" "golang.org/x/crypto/bcrypt" "krzyzanowski.dev/p2pchat/common" ) @@ -19,6 +19,8 @@ type Account struct { } type ServerContext struct { + idCounter int + idCounterLock sync.RWMutex peersList []*Peer peersListLock sync.RWMutex accounts map[string]*Account @@ -26,17 +28,21 @@ type ServerContext struct { } type HandlerContext struct { - peer *Peer - srvCtx *ServerContext + peer *Peer + *ServerContext } type Peer struct { id int - conn net.Conn + conn *websocket.Conn hasAccount bool account *Account } +func NewPeer(conn *websocket.Conn) *Peer { + return &Peer{-1, conn, false, nil} +} + func peerSliceIndexOf(s []*Peer, id int) int { i := 0 var p *Peer @@ -53,98 +59,40 @@ func peerSliceRemove(s *[]*Peer, i int) { *s = (*s)[:len(*s)-1] } +func (srvCtx *ServerContext) removePeer(peer *Peer) { + srvCtx.peersListLock.Lock() + peerSliceRemove(&srvCtx.peersList, peerSliceIndexOf(srvCtx.peersList, peer.id)) + srvCtx.peersListLock.Unlock() +} + func handleDisconnection(handlerCtx *HandlerContext) { - handlerCtx.srvCtx.peersListLock.Lock() - p := handlerCtx.srvCtx.peersList[peerSliceIndexOf(handlerCtx.srvCtx.peersList, handlerCtx.peer.id)] - log.Printf("[Server] %s disconnected\n", p.conn.RemoteAddr()) - peerSliceRemove(&handlerCtx.srvCtx.peersList, peerSliceIndexOf(handlerCtx.srvCtx.peersList, handlerCtx.peer.id)) - handlerCtx.srvCtx.peersListLock.Unlock() + handlerCtx.removePeer(handlerCtx.peer) + log.Printf("[Server] %s disconnected\n", handlerCtx.peer.conn.RemoteAddr()) } -func handlePeer(handlerCtx *HandlerContext) { - br := bufio.NewReader(handlerCtx.peer.conn) - bw := bufio.NewWriter(handlerCtx.peer.conn) - - for { - reqBytes, err := br.ReadBytes('\n') - - if err == io.EOF { - handleDisconnection(handlerCtx) - break - } else if err != nil { - log.Println(err) - break - } - - if len(reqBytes) <= 1 { - log.Println("got request without id") - break - } - - reqBytes = reqBytes[:len(reqBytes)-1] - operationCode := reqBytes[0] - reqJsonBytes := reqBytes[1:] - var resBytes []byte - - if operationCode == common.EchoRID { - resBytes, err = handleEcho(handlerCtx, reqJsonBytes) - } else if operationCode == common.ListPeersRID { - resBytes, err = handleListPeers(handlerCtx, reqJsonBytes) - } else if operationCode == common.AuthRID { - resBytes, err = handleAuth(handlerCtx, reqJsonBytes) - } - - if err != nil { - log.Println(err) - continue - } - - resBytes = append(resBytes, '\n') - _, err = bw.Write(resBytes) - - if err != nil { - log.Println(err) - continue - } - - err = bw.Flush() - - if err != nil { - log.Println(err) - } - } -} - -func handleEcho(_ *HandlerContext, reqBytes []byte) (resBytes []byte, err error) { - var echoReq common.EchoRequest - err = json.Unmarshal(reqBytes, &echoReq) - +func (hdlCtx *HandlerContext) handleEcho(reqFrame *common.RequestFrame) (res common.Response, err error) { + echoReq, err := common.RequestFromFrame[common.EchoRequest](*reqFrame) if err != nil { + log.Println("[Server] could not read request from frame") return nil, err } echoRes := common.EchoResponse(echoReq) - resBytes, err = json.Marshal(echoRes) - - if err != nil { - return nil, err - } - - return resBytes, nil + return echoRes, nil } -func handleListPeers(handlerCtx *HandlerContext, reqBytes []byte) (resBytes []byte, err error) { - var listPeersReq common.ListPeersRequest - err = json.Unmarshal(reqBytes, &listPeersReq) - +func (hdlCtx *HandlerContext) handleListPeers(reqFrame *common.RequestFrame) (res common.Response, err error) { + // Currently list peers request is empty, so we can ignore it - we won't use it + _, err = common.RequestFromFrame[common.ListPeersRequest](*reqFrame) if err != nil { + log.Println("[Server] could not read request from frame") return nil, err } - handlerCtx.srvCtx.peersListLock.RLock() - peersFreeze := make([]*Peer, len(handlerCtx.srvCtx.peersList)) - copy(peersFreeze, handlerCtx.srvCtx.peersList) - handlerCtx.srvCtx.peersListLock.RUnlock() + hdlCtx.peersListLock.RLock() + peersFreeze := make([]*Peer, len(hdlCtx.peersList)) + copy(peersFreeze, hdlCtx.peersList) + hdlCtx.peersListLock.RUnlock() listPeersRes := common.ListPeersResponse{PeersInfo: make([]common.PeerInfo, 0)} for _, peer := range peersFreeze { @@ -159,68 +107,55 @@ func handleListPeers(handlerCtx *HandlerContext, reqBytes []byte) (resBytes []by ) } - resBytes, err = json.Marshal(listPeersRes) - - if err != nil { - return nil, err - } - - return resBytes, nil + return listPeersRes, nil } -func handleAuth(handlerCtx *HandlerContext, reqBytes []byte) (resBytes []byte, err error) { - var authReq common.AuthRequest - err = json.Unmarshal(reqBytes, &authReq) - +func (hdlCtx *HandlerContext) handleAuth(reqFrame *common.RequestFrame) (res common.Response, err error) { + authReq, err := common.RequestFromFrame[common.AuthRequest](*reqFrame) if err != nil { + log.Println("[Server] could not read request from frame") return nil, err } // Check if account already exists - handlerCtx.srvCtx.accountsLock.RLock() - account, ok := handlerCtx.srvCtx.accounts[authReq.Nickname] - handlerCtx.srvCtx.accountsLock.RUnlock() - var authRes common.AuthResponse + hdlCtx.accountsLock.RLock() + account, ok := hdlCtx.accounts[authReq.Nickname] + hdlCtx.accountsLock.RUnlock() + var authRes *common.AuthResponse if ok { // Check if password matches if bcrypt.CompareHashAndPassword(account.passHash, []byte(authReq.Password)) == nil { - authRes = common.AuthResponse{IsSuccess: true} - handlerCtx.srvCtx.peersListLock.Lock() - handlerCtx.peer.hasAccount = true - handlerCtx.peer.account = account - handlerCtx.srvCtx.peersListLock.Unlock() + authRes = &common.AuthResponse{IsSuccess: true} + hdlCtx.peersListLock.Lock() + hdlCtx.peer.hasAccount = true + hdlCtx.peer.account = account + hdlCtx.peersListLock.Unlock() } else { - authRes = common.AuthResponse{IsSuccess: false} + authRes = &common.AuthResponse{IsSuccess: false} } } else { - authRes = common.AuthResponse{IsSuccess: true} + authRes = &common.AuthResponse{IsSuccess: true} passHash, err := bcrypt.GenerateFromPassword([]byte(authReq.Password), bcrypt.DefaultCost) if err != nil { - authRes = common.AuthResponse{IsSuccess: false} + authRes = &common.AuthResponse{IsSuccess: false} } else { newAcc := Account{authReq.Nickname, passHash} - handlerCtx.srvCtx.accountsLock.Lock() - handlerCtx.srvCtx.accounts[newAcc.nickname] = &newAcc - handlerCtx.srvCtx.accountsLock.Unlock() - handlerCtx.srvCtx.peersListLock.Lock() - handlerCtx.peer.hasAccount = true - handlerCtx.peer.account = &newAcc - handlerCtx.srvCtx.peersListLock.Unlock() + hdlCtx.accountsLock.Lock() + hdlCtx.accounts[newAcc.nickname] = &newAcc + hdlCtx.accountsLock.Unlock() + hdlCtx.peersListLock.Lock() + hdlCtx.peer.hasAccount = true + hdlCtx.peer.account = &newAcc + hdlCtx.peersListLock.Unlock() } } - resBytes, err = json.Marshal(authRes) - - if err != nil { - return nil, err - } - - return resBytes, nil + return authRes, nil } -func printConnectedPeers(srvCtx *ServerContext) { +func (srvCtx *ServerContext) printConnectedPeers() { srvCtx.peersListLock.RLock() log.Println("[Server] displaying all connections:") @@ -237,36 +172,107 @@ func printConnectedPeers(srvCtx *ServerContext) { srvCtx.peersListLock.RUnlock() } -func RunServer() { - idCounter := 0 - srvCtx := &ServerContext{peersList: make([]*Peer, 0), accounts: make(map[string]*Account)} - ln, err := net.Listen("tcp", ":8080") +func (hdlCtx *HandlerContext) handleRequest(reqJsonBytes []byte) error { + log.Printf("[Server] got message text: %s\n", strings.Trim(string(reqJsonBytes), "\n")) + var reqFrame common.RequestFrame + json.Unmarshal(reqJsonBytes, &reqFrame) + log.Printf("[Server] unmarshalled request frame (ID=%d)\n", reqFrame.ID) + var res common.Response + var err error - if err != nil { - log.Println(err) + if reqFrame.ID == common.AuthRID { + res, err = hdlCtx.handleAuth(&reqFrame) + } else if reqFrame.ID == common.ListPeersRID { + res, err = hdlCtx.handleListPeers(&reqFrame) + } else if reqFrame.ID == common.EchoRID { + res, err = hdlCtx.handleEcho(&reqFrame) } - go func() { - for { - printConnectedPeers(srvCtx) - time.Sleep(time.Second * 5) - } - }() + if err != nil { + log.Printf("[Server] could not handle request ID=%d\n", reqFrame.ID) + return err + } + + resFrame, err := common.ResponseFrameFrom(res) + if err != nil { + log.Println("[Server] could not create frame from response") + return err + } + + resJsonBytes, err := json.Marshal(resFrame) + if err != nil { + log.Println("[Server] error marshalling frame to json") + return err + } + + log.Printf("[Server] sending %s\n", string(resJsonBytes)) + err = hdlCtx.peer.conn.WriteMessage(websocket.TextMessage, resJsonBytes) + if err != nil { + log.Println("[Server] error writing response frame") + return err + } + + return nil +} + +func (srvCtx *ServerContext) addPeer(peer *Peer) { + srvCtx.idCounterLock.Lock() + srvCtx.idCounter++ + peer.id = srvCtx.idCounter + srvCtx.idCounterLock.Unlock() + srvCtx.peersListLock.Lock() + srvCtx.peersList = append(srvCtx.peersList, peer) + srvCtx.peersListLock.Unlock() +} + +func (srvCtx *ServerContext) wsapiHandler(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{} + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println("[Server] upgrade failed") + return + } + + peer := NewPeer(conn) + srvCtx.addPeer(peer) + handlerCtx := &HandlerContext{peer, srvCtx} + defer handleDisconnection(handlerCtx) + defer conn.Close() + log.Printf("[Server] %s connected\n", conn.RemoteAddr()) for { - c, err := ln.Accept() + messType, messBytes, err := conn.ReadMessage() + if err != nil { + break + } + if messType != 1 { + err := conn.WriteMessage(websocket.CloseUnsupportedData, []byte("Only JSON text is supported")) + if err != nil { + log.Println("[Server] error sending close message due to unsupported data") + } + + return + } + + err = handlerCtx.handleRequest(messBytes) if err != nil { log.Println(err) break } - - log.Printf("[Server] client connected %s\n", c.RemoteAddr()) - idCounter++ - newPeer := Peer{idCounter, c, false, nil} - srvCtx.peersListLock.Lock() - srvCtx.peersList = append(srvCtx.peersList, &newPeer) - srvCtx.peersListLock.Unlock() - go handlePeer(&HandlerContext{&newPeer, srvCtx}) } } + +func RunServer() { + srvCtx := &ServerContext{peersList: make([]*Peer, 0), accounts: make(map[string]*Account)} + + go func() { + for { + srvCtx.printConnectedPeers() + time.Sleep(time.Second * 5) + } + }() + + http.HandleFunc("/wsapi", srvCtx.wsapiHandler) + http.ListenAndServe(":8080", nil) +}