From 90b837fe5cccb95c15d0aad04afe3d12fc3062d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20Krzy=C5=BCanowski?= Date: Fri, 22 Mar 2024 16:04:42 +0100 Subject: [PATCH] Added authentication --- main.go | 160 ++++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 126 insertions(+), 34 deletions(-) diff --git a/main.go b/main.go index a21a824..5a3f8f0 100644 --- a/main.go +++ b/main.go @@ -11,19 +11,35 @@ import ( "time" ) +type Account struct { + nickname string + password string +} + type ServerContext struct { - peersList []Peer + peersList []*Peer peersListLock sync.RWMutex + accounts map[string]*Account + accountsLock sync.RWMutex +} + +type HandlerContext struct { + peer *Peer + srvCtx *ServerContext } type Peer struct { - id int - conn net.Conn + id int + conn net.Conn + hasAccount bool + account *Account } type PeerInfo struct { - ID int `json:"id"` - Addr string `json:"addr"` + ID int `json:"id"` + Addr string `json:"addr"` + HasNickaname bool `json:"hasNickname"` + Nickname string `json:"nickname"` } type EchoRequest struct { @@ -41,14 +57,24 @@ type ListPeersResponse struct { PeersInfo []PeerInfo `json:"peers"` } +type AuthRequest struct { + Nickname string + Password string +} + +type AuthResponse struct { + IsSuccess bool +} + const ( echoRID = 1 listPeersRID = 2 + authRID = 3 ) -func peerSliceIndexOf(s []Peer, id int) int { +func peerSliceIndexOf(s []*Peer, id int) int { i := 0 - var p Peer + var p *Peer for i, p = range s { if p.id == id { break @@ -57,28 +83,28 @@ func peerSliceIndexOf(s []Peer, id int) int { return i } -func peerSliceRemove(s *[]Peer, i int) { +func peerSliceRemove(s *[]*Peer, i int) { (*s)[i] = (*s)[len(*s)-1] *s = (*s)[:len(*s)-1] } -func handleDisconnection(srvCtx *ServerContext, id int) { - srvCtx.peersListLock.Lock() - p := srvCtx.peersList[peerSliceIndexOf(srvCtx.peersList, id)] +func handleDisconnection(handlerCtx *HandlerContext) { + handlerCtx.srvCtx.peersListLock.Lock() + p := handlerCtx.srvCtx.peersList[peerSliceIndexOf(handlerCtx.srvCtx.peersList, handlerCtx.peer.id)] log.Printf("[Server] %s disconnected\n", p.conn.RemoteAddr()) - peerSliceRemove(&srvCtx.peersList, peerSliceIndexOf(srvCtx.peersList, id)) - srvCtx.peersListLock.Unlock() + peerSliceRemove(&handlerCtx.srvCtx.peersList, peerSliceIndexOf(handlerCtx.srvCtx.peersList, handlerCtx.peer.id)) + handlerCtx.srvCtx.peersListLock.Unlock() } -func handlePeer(srvCtx *ServerContext, p Peer) { - br := bufio.NewReader(p.conn) - bw := bufio.NewWriter(p.conn) +func handlePeer(handlerCtx *HandlerContext) { + br := bufio.NewReader(handlerCtx.peer.conn) + bw := bufio.NewWriter(handlerCtx.peer.conn) for { reqBytes, err := br.ReadBytes('\n') if err == io.EOF { - handleDisconnection(srvCtx, p.id) + handleDisconnection(handlerCtx) break } else if err != nil { log.Println(err) @@ -96,9 +122,11 @@ func handlePeer(srvCtx *ServerContext, p Peer) { var resBytes []byte if operationCode == echoRID { - resBytes, err = handleEcho(srvCtx, reqJsonBytes) + resBytes, err = handleEcho(handlerCtx, reqJsonBytes) } else if operationCode == listPeersRID { - resBytes, err = handleListPeers(srvCtx, reqJsonBytes) + resBytes, err = handleListPeers(handlerCtx, reqJsonBytes) + } else if operationCode == authRID { + resBytes, err = handleAuth(handlerCtx, reqJsonBytes) } if err != nil { @@ -122,7 +150,7 @@ func handlePeer(srvCtx *ServerContext, p Peer) { } } -func handleEcho(_ *ServerContext, reqBytes []byte) (resBytes []byte, err error) { +func handleEcho(_ *HandlerContext, reqBytes []byte) (resBytes []byte, err error) { var echoReq EchoRequest err = json.Unmarshal(reqBytes, &echoReq) @@ -140,7 +168,7 @@ func handleEcho(_ *ServerContext, reqBytes []byte) (resBytes []byte, err error) return resBytes, nil } -func handleListPeers(srvCtx *ServerContext, reqBytes []byte) (resBytes []byte, err error) { +func handleListPeers(handlerCtx *HandlerContext, reqBytes []byte) (resBytes []byte, err error) { // For the sake of conciseness -> currently unmarshalling empty slice to empty struct var listPeersReq ListPeersRequest err = json.Unmarshal(reqBytes, &listPeersReq) @@ -149,16 +177,16 @@ func handleListPeers(srvCtx *ServerContext, reqBytes []byte) (resBytes []byte, e return nil, err } - srvCtx.peersListLock.RLock() - peersFreeze := make([]Peer, len(srvCtx.peersList)) - copy(peersFreeze, srvCtx.peersList) - srvCtx.peersListLock.RUnlock() + handlerCtx.srvCtx.peersListLock.RLock() + peersFreeze := make([]*Peer, len(handlerCtx.srvCtx.peersList)) + copy(peersFreeze, handlerCtx.srvCtx.peersList) + handlerCtx.srvCtx.peersListLock.RUnlock() listPeersRes := ListPeersResponse{make([]PeerInfo, 0)} for _, peer := range peersFreeze { listPeersRes.PeersInfo = append( listPeersRes.PeersInfo, - PeerInfo{peer.id, peer.conn.RemoteAddr().String()}, + PeerInfo{peer.id, peer.conn.RemoteAddr().String(), peer.hasAccount, peer.account.nickname}, ) } @@ -171,12 +199,64 @@ func handleListPeers(srvCtx *ServerContext, reqBytes []byte) (resBytes []byte, e return resBytes, nil } +func handleAuth(handlerCtx *HandlerContext, reqBytes []byte) (resBytes []byte, err error) { + var authReq AuthRequest + err = json.Unmarshal(reqBytes, &authReq) + + if err != nil { + return nil, err + } + + // Check if account already exists + handlerCtx.srvCtx.accountsLock.RLock() + account, ok := handlerCtx.srvCtx.accounts[authReq.Nickname] + handlerCtx.srvCtx.accountsLock.RUnlock() + var authRes AuthResponse + + if ok { + // Check if password matches + if authReq.Password == account.password { + authRes = AuthResponse{true} + handlerCtx.srvCtx.peersListLock.Lock() + handlerCtx.peer.hasAccount = true + handlerCtx.peer.account = account + handlerCtx.srvCtx.peersListLock.Unlock() + } else { + authRes = AuthResponse{false} + } + } else { + authRes = AuthResponse{true} + newAcc := Account{authReq.Nickname, authReq.Password} + handlerCtx.srvCtx.accountsLock.Lock() + handlerCtx.srvCtx.accounts[newAcc.nickname] = &newAcc + handlerCtx.srvCtx.accountsLock.Unlock() + handlerCtx.srvCtx.peersListLock.Lock() + handlerCtx.peer.hasAccount = true + handlerCtx.peer.account = &newAcc + handlerCtx.srvCtx.peersListLock.Unlock() + } + + resBytes, err = json.Marshal(authRes) + + if err != nil { + return nil, err + } + + return resBytes, nil +} + func printConnectedPeers(srvCtx *ServerContext) { srvCtx.peersListLock.RLock() - log.Println("[Server] Displaying all connections:") + log.Println("[Server] displaying all connections:") for _, p := range srvCtx.peersList { - log.Printf("[Server] ID#%d: %s\n", p.id, p.conn.RemoteAddr()) + nick := "-" + + if p.hasAccount { + nick = p.account.nickname + } + + log.Printf("[Server] ID#%d, Addr:%s, Auth:%t, Nick:%s\n", p.id, p.conn.RemoteAddr(), p.hasAccount, nick) } srvCtx.peersListLock.RUnlock() @@ -184,7 +264,8 @@ func printConnectedPeers(srvCtx *ServerContext) { func runServer() { idCounter := 0 - srvCtx := &ServerContext{peersList: make([]Peer, 0)} + srvCtx := &ServerContext{peersList: make([]*Peer, 0), accounts: make(map[string]*Account)} + srvCtx.accounts["xd"] = &Account{"xd", "XD"} ln, err := net.Listen("tcp", ":8080") if err != nil { @@ -208,11 +289,11 @@ func runServer() { log.Printf("[Server] client connected %s\n", c.RemoteAddr()) idCounter++ - newPeer := Peer{idCounter, c} + newPeer := Peer{idCounter, c, false, nil} srvCtx.peersListLock.Lock() - srvCtx.peersList = append(srvCtx.peersList, newPeer) + srvCtx.peersList = append(srvCtx.peersList, &newPeer) srvCtx.peersListLock.Unlock() - go handlePeer(srvCtx, newPeer) + go handlePeer(&HandlerContext{&newPeer, srvCtx}) } } @@ -245,6 +326,17 @@ func runClient() { json.Unmarshal(resBytes, &echoRes) log.Printf("[Client] echo sent (5), got %d\n", echoRes.EchoByte) + authReq := AuthRequest{"maciek", "9maciek1"} + reqBytes, _ = json.Marshal(authReq) + bw.WriteByte(authRID) + bw.Write(reqBytes) + bw.WriteByte('\n') + bw.Flush() + resBytes, _ = br.ReadBytes('\n') + var authRes AuthResponse + json.Unmarshal(resBytes, &authRes) + log.Printf("[Client] authenticated: %t\n", authRes.IsSuccess) + listReq := ListPeersRequest{} reqBytes, _ = json.Marshal(listReq) bw.WriteByte(listPeersRID) @@ -257,10 +349,10 @@ func runClient() { log.Println("[Client] printing all peers:") for _, peer := range listRes.PeersInfo { - log.Printf("[Client] Peer#%d from %s", peer.ID, peer.Addr) + log.Printf("[Client] Peer#%d from %s, hasNick: %t, nick: %s", peer.ID, peer.Addr, peer.HasNickaname, peer.Nickname) } - time.Sleep(time.Second * 5) + time.Sleep(time.Second * 10) } func main() {