Added authentication

This commit is contained in:
Maciej Krzyżanowski 2024-03-22 16:04:42 +01:00
parent 7793298474
commit 90b837fe5c

160
main.go
View File

@ -11,19 +11,35 @@ import (
"time" "time"
) )
type Account struct {
nickname string
password string
}
type ServerContext struct { type ServerContext struct {
peersList []Peer peersList []*Peer
peersListLock sync.RWMutex peersListLock sync.RWMutex
accounts map[string]*Account
accountsLock sync.RWMutex
}
type HandlerContext struct {
peer *Peer
srvCtx *ServerContext
} }
type Peer struct { type Peer struct {
id int id int
conn net.Conn conn net.Conn
hasAccount bool
account *Account
} }
type PeerInfo struct { type PeerInfo struct {
ID int `json:"id"` ID int `json:"id"`
Addr string `json:"addr"` Addr string `json:"addr"`
HasNickaname bool `json:"hasNickname"`
Nickname string `json:"nickname"`
} }
type EchoRequest struct { type EchoRequest struct {
@ -41,14 +57,24 @@ type ListPeersResponse struct {
PeersInfo []PeerInfo `json:"peers"` PeersInfo []PeerInfo `json:"peers"`
} }
type AuthRequest struct {
Nickname string
Password string
}
type AuthResponse struct {
IsSuccess bool
}
const ( const (
echoRID = 1 echoRID = 1
listPeersRID = 2 listPeersRID = 2
authRID = 3
) )
func peerSliceIndexOf(s []Peer, id int) int { func peerSliceIndexOf(s []*Peer, id int) int {
i := 0 i := 0
var p Peer var p *Peer
for i, p = range s { for i, p = range s {
if p.id == id { if p.id == id {
break break
@ -57,28 +83,28 @@ func peerSliceIndexOf(s []Peer, id int) int {
return i return i
} }
func peerSliceRemove(s *[]Peer, i int) { func peerSliceRemove(s *[]*Peer, i int) {
(*s)[i] = (*s)[len(*s)-1] (*s)[i] = (*s)[len(*s)-1]
*s = (*s)[:len(*s)-1] *s = (*s)[:len(*s)-1]
} }
func handleDisconnection(srvCtx *ServerContext, id int) { func handleDisconnection(handlerCtx *HandlerContext) {
srvCtx.peersListLock.Lock() handlerCtx.srvCtx.peersListLock.Lock()
p := srvCtx.peersList[peerSliceIndexOf(srvCtx.peersList, id)] p := handlerCtx.srvCtx.peersList[peerSliceIndexOf(handlerCtx.srvCtx.peersList, handlerCtx.peer.id)]
log.Printf("[Server] %s disconnected\n", p.conn.RemoteAddr()) log.Printf("[Server] %s disconnected\n", p.conn.RemoteAddr())
peerSliceRemove(&srvCtx.peersList, peerSliceIndexOf(srvCtx.peersList, id)) peerSliceRemove(&handlerCtx.srvCtx.peersList, peerSliceIndexOf(handlerCtx.srvCtx.peersList, handlerCtx.peer.id))
srvCtx.peersListLock.Unlock() handlerCtx.srvCtx.peersListLock.Unlock()
} }
func handlePeer(srvCtx *ServerContext, p Peer) { func handlePeer(handlerCtx *HandlerContext) {
br := bufio.NewReader(p.conn) br := bufio.NewReader(handlerCtx.peer.conn)
bw := bufio.NewWriter(p.conn) bw := bufio.NewWriter(handlerCtx.peer.conn)
for { for {
reqBytes, err := br.ReadBytes('\n') reqBytes, err := br.ReadBytes('\n')
if err == io.EOF { if err == io.EOF {
handleDisconnection(srvCtx, p.id) handleDisconnection(handlerCtx)
break break
} else if err != nil { } else if err != nil {
log.Println(err) log.Println(err)
@ -96,9 +122,11 @@ func handlePeer(srvCtx *ServerContext, p Peer) {
var resBytes []byte var resBytes []byte
if operationCode == echoRID { if operationCode == echoRID {
resBytes, err = handleEcho(srvCtx, reqJsonBytes) resBytes, err = handleEcho(handlerCtx, reqJsonBytes)
} else if operationCode == listPeersRID { } else if operationCode == listPeersRID {
resBytes, err = handleListPeers(srvCtx, reqJsonBytes) resBytes, err = handleListPeers(handlerCtx, reqJsonBytes)
} else if operationCode == authRID {
resBytes, err = handleAuth(handlerCtx, reqJsonBytes)
} }
if err != nil { if err != nil {
@ -122,7 +150,7 @@ func handlePeer(srvCtx *ServerContext, p Peer) {
} }
} }
func handleEcho(_ *ServerContext, reqBytes []byte) (resBytes []byte, err error) { func handleEcho(_ *HandlerContext, reqBytes []byte) (resBytes []byte, err error) {
var echoReq EchoRequest var echoReq EchoRequest
err = json.Unmarshal(reqBytes, &echoReq) err = json.Unmarshal(reqBytes, &echoReq)
@ -140,7 +168,7 @@ func handleEcho(_ *ServerContext, reqBytes []byte) (resBytes []byte, err error)
return resBytes, nil return resBytes, nil
} }
func handleListPeers(srvCtx *ServerContext, reqBytes []byte) (resBytes []byte, err error) { func handleListPeers(handlerCtx *HandlerContext, reqBytes []byte) (resBytes []byte, err error) {
// For the sake of conciseness -> currently unmarshalling empty slice to empty struct // For the sake of conciseness -> currently unmarshalling empty slice to empty struct
var listPeersReq ListPeersRequest var listPeersReq ListPeersRequest
err = json.Unmarshal(reqBytes, &listPeersReq) err = json.Unmarshal(reqBytes, &listPeersReq)
@ -149,16 +177,16 @@ func handleListPeers(srvCtx *ServerContext, reqBytes []byte) (resBytes []byte, e
return nil, err return nil, err
} }
srvCtx.peersListLock.RLock() handlerCtx.srvCtx.peersListLock.RLock()
peersFreeze := make([]Peer, len(srvCtx.peersList)) peersFreeze := make([]*Peer, len(handlerCtx.srvCtx.peersList))
copy(peersFreeze, srvCtx.peersList) copy(peersFreeze, handlerCtx.srvCtx.peersList)
srvCtx.peersListLock.RUnlock() handlerCtx.srvCtx.peersListLock.RUnlock()
listPeersRes := ListPeersResponse{make([]PeerInfo, 0)} listPeersRes := ListPeersResponse{make([]PeerInfo, 0)}
for _, peer := range peersFreeze { for _, peer := range peersFreeze {
listPeersRes.PeersInfo = append( listPeersRes.PeersInfo = append(
listPeersRes.PeersInfo, listPeersRes.PeersInfo,
PeerInfo{peer.id, peer.conn.RemoteAddr().String()}, PeerInfo{peer.id, peer.conn.RemoteAddr().String(), peer.hasAccount, peer.account.nickname},
) )
} }
@ -171,12 +199,64 @@ func handleListPeers(srvCtx *ServerContext, reqBytes []byte) (resBytes []byte, e
return resBytes, nil return resBytes, nil
} }
func handleAuth(handlerCtx *HandlerContext, reqBytes []byte) (resBytes []byte, err error) {
var authReq AuthRequest
err = json.Unmarshal(reqBytes, &authReq)
if err != nil {
return nil, err
}
// Check if account already exists
handlerCtx.srvCtx.accountsLock.RLock()
account, ok := handlerCtx.srvCtx.accounts[authReq.Nickname]
handlerCtx.srvCtx.accountsLock.RUnlock()
var authRes AuthResponse
if ok {
// Check if password matches
if authReq.Password == account.password {
authRes = AuthResponse{true}
handlerCtx.srvCtx.peersListLock.Lock()
handlerCtx.peer.hasAccount = true
handlerCtx.peer.account = account
handlerCtx.srvCtx.peersListLock.Unlock()
} else {
authRes = AuthResponse{false}
}
} else {
authRes = AuthResponse{true}
newAcc := Account{authReq.Nickname, authReq.Password}
handlerCtx.srvCtx.accountsLock.Lock()
handlerCtx.srvCtx.accounts[newAcc.nickname] = &newAcc
handlerCtx.srvCtx.accountsLock.Unlock()
handlerCtx.srvCtx.peersListLock.Lock()
handlerCtx.peer.hasAccount = true
handlerCtx.peer.account = &newAcc
handlerCtx.srvCtx.peersListLock.Unlock()
}
resBytes, err = json.Marshal(authRes)
if err != nil {
return nil, err
}
return resBytes, nil
}
func printConnectedPeers(srvCtx *ServerContext) { func printConnectedPeers(srvCtx *ServerContext) {
srvCtx.peersListLock.RLock() srvCtx.peersListLock.RLock()
log.Println("[Server] Displaying all connections:") log.Println("[Server] displaying all connections:")
for _, p := range srvCtx.peersList { for _, p := range srvCtx.peersList {
log.Printf("[Server] ID#%d: %s\n", p.id, p.conn.RemoteAddr()) nick := "-"
if p.hasAccount {
nick = p.account.nickname
}
log.Printf("[Server] ID#%d, Addr:%s, Auth:%t, Nick:%s\n", p.id, p.conn.RemoteAddr(), p.hasAccount, nick)
} }
srvCtx.peersListLock.RUnlock() srvCtx.peersListLock.RUnlock()
@ -184,7 +264,8 @@ func printConnectedPeers(srvCtx *ServerContext) {
func runServer() { func runServer() {
idCounter := 0 idCounter := 0
srvCtx := &ServerContext{peersList: make([]Peer, 0)} srvCtx := &ServerContext{peersList: make([]*Peer, 0), accounts: make(map[string]*Account)}
srvCtx.accounts["xd"] = &Account{"xd", "XD"}
ln, err := net.Listen("tcp", ":8080") ln, err := net.Listen("tcp", ":8080")
if err != nil { if err != nil {
@ -208,11 +289,11 @@ func runServer() {
log.Printf("[Server] client connected %s\n", c.RemoteAddr()) log.Printf("[Server] client connected %s\n", c.RemoteAddr())
idCounter++ idCounter++
newPeer := Peer{idCounter, c} newPeer := Peer{idCounter, c, false, nil}
srvCtx.peersListLock.Lock() srvCtx.peersListLock.Lock()
srvCtx.peersList = append(srvCtx.peersList, newPeer) srvCtx.peersList = append(srvCtx.peersList, &newPeer)
srvCtx.peersListLock.Unlock() srvCtx.peersListLock.Unlock()
go handlePeer(srvCtx, newPeer) go handlePeer(&HandlerContext{&newPeer, srvCtx})
} }
} }
@ -245,6 +326,17 @@ func runClient() {
json.Unmarshal(resBytes, &echoRes) json.Unmarshal(resBytes, &echoRes)
log.Printf("[Client] echo sent (5), got %d\n", echoRes.EchoByte) log.Printf("[Client] echo sent (5), got %d\n", echoRes.EchoByte)
authReq := AuthRequest{"maciek", "9maciek1"}
reqBytes, _ = json.Marshal(authReq)
bw.WriteByte(authRID)
bw.Write(reqBytes)
bw.WriteByte('\n')
bw.Flush()
resBytes, _ = br.ReadBytes('\n')
var authRes AuthResponse
json.Unmarshal(resBytes, &authRes)
log.Printf("[Client] authenticated: %t\n", authRes.IsSuccess)
listReq := ListPeersRequest{} listReq := ListPeersRequest{}
reqBytes, _ = json.Marshal(listReq) reqBytes, _ = json.Marshal(listReq)
bw.WriteByte(listPeersRID) bw.WriteByte(listPeersRID)
@ -257,10 +349,10 @@ func runClient() {
log.Println("[Client] printing all peers:") log.Println("[Client] printing all peers:")
for _, peer := range listRes.PeersInfo { for _, peer := range listRes.PeersInfo {
log.Printf("[Client] Peer#%d from %s", peer.ID, peer.Addr) log.Printf("[Client] Peer#%d from %s, hasNick: %t, nick: %s", peer.ID, peer.Addr, peer.HasNickaname, peer.Nickname)
} }
time.Sleep(time.Second * 5) time.Sleep(time.Second * 10)
} }
func main() { func main() {