Refactor server.go

This commit is contained in:
Maciej Krzyżanowski 2024-06-21 13:41:10 +02:00
parent 2468c6d077
commit f7dc487546

View File

@ -1,3 +1,4 @@
// Package server includes functions and structures which allow to run arhcat connection initiating server
package server package server
import ( import (
@ -21,70 +22,60 @@ import (
"krzyzanowski.dev/archat/common" "krzyzanowski.dev/archat/common"
) )
type Peer struct { type peer struct {
id int id int
conn *websocket.Conn conn *websocket.Conn
hasAccount bool hasAccount bool
account *Account account *account
} }
func NewPeer(conn *websocket.Conn) *Peer { func newPeer(conn *websocket.Conn) *peer {
return &Peer{-1, conn, false, nil} return &peer{-1, conn, false, nil}
} }
func (p *Peer) NicknameOrEmpty() string { func (p *peer) NicknameOrEmpty() string {
if p.hasAccount { if p.hasAccount {
return p.account.nickname return p.account.nickname
} else {
return ""
} }
return ""
} }
type Account struct { type account struct {
nickname string nickname string
passHash []byte passHash []byte
} }
func NewInitiation(abA string, abB string) *common.Initiation { func newInitiation(abA string, abB string) *common.Initiation {
return &common.Initiation{AbANick: abA, AbBNick: abB, Stage: common.InitiationStageA} return &common.Initiation{AbANick: abA, AbBNick: abB, Stage: common.InitiationStageA}
} }
type Context struct { type serverContext struct {
idCounter int idCounter int
idCounterLock sync.RWMutex idCounterLock sync.RWMutex
peersList []*Peer peersList []*peer
peersListLock sync.RWMutex peersListLock sync.RWMutex
accounts map[string]*Account accounts map[string]*account
accountsLock sync.RWMutex accountsLock sync.RWMutex
initiations []*common.Initiation initiations []*common.Initiation
initiationsLock sync.RWMutex initiationsLock sync.RWMutex
handlerContexts []*HandlerContext handlerContexts []*handlerContext
handlerContextsLock sync.RWMutex handlerContextsLock sync.RWMutex
} }
func NewContext() *Context { func newContext() *serverContext {
return &Context{ return &serverContext{
peersList: make([]*Peer, 0), peersList: make([]*peer, 0),
accounts: make(map[string]*Account), accounts: make(map[string]*account),
initiations: make([]*common.Initiation, 0), initiations: make([]*common.Initiation, 0),
} }
} }
// Remember to lock before calling // Remember to lock before calling
func (ctx *Context) getPeerByNick(nick string) (*Peer, error) { func (ctx *serverContext) getCtxByNick(nick string) (*handlerContext, error) {
for _, peer := range ctx.peersList { idx := slices.IndexFunc(
if peer.hasAccount && peer.account.nickname == nick {
return peer, nil
}
}
return nil, errors.New("peer not found")
}
// Remember to lock before calling
func (ctx *Context) getCtxByNick(nick string) (*HandlerContext, error) {
idx := slices.IndexFunc[[]*HandlerContext, *HandlerContext](
ctx.handlerContexts, ctx.handlerContexts,
func(handlerContext *HandlerContext) bool { func(handlerContext *handlerContext) bool {
return handlerContext.peer.hasAccount && handlerContext.peer.account.nickname == nick return handlerContext.peer.hasAccount && handlerContext.peer.account.nickname == nick
}) })
@ -95,16 +86,16 @@ func (ctx *Context) getCtxByNick(nick string) (*HandlerContext, error) {
return nil, errors.New("not found") return nil, errors.New("not found")
} }
type HandlerContext struct { type handlerContext struct {
peer *Peer peer *peer
*Context *serverContext
resFromClient chan common.RFrame resFromClient chan common.RFrame
reqFromClient chan common.RFrame reqFromClient chan common.RFrame
rToClient chan common.RFrame rToClient chan common.RFrame
} }
func NewHandlerContext(peer *Peer, srvCtx *Context) *HandlerContext { func newHandlerContext(peer *peer, srvCtx *serverContext) *handlerContext {
return &HandlerContext{ return &handlerContext{
peer, peer,
srvCtx, srvCtx,
make(chan common.RFrame), make(chan common.RFrame),
@ -113,7 +104,7 @@ func NewHandlerContext(peer *Peer, srvCtx *Context) *HandlerContext {
} }
} }
func (hdlCtx *HandlerContext) clientHandler(syncCtx context.Context) error { func (hdlCtx *handlerContext) clientHandler(syncCtx context.Context) error {
handleNext: handleNext:
for { for {
select { select {
@ -160,21 +151,21 @@ handleNext:
} }
} }
func (hdlCtx *HandlerContext) clientWriter(syncCtx context.Context) error { func (hdlCtx *handlerContext) clientWriter(syncCtx context.Context) error {
for { for {
select { select {
case <-syncCtx.Done(): case <-syncCtx.Done():
return nil return nil
case rFrame := <-hdlCtx.rToClient: case rFrame := <-hdlCtx.rToClient:
resJsonBytes, err := json.Marshal(rFrame) resJSONBytes, err := json.Marshal(rFrame)
if err != nil { if err != nil {
logger.Errorf("error marshalling frame to json") logger.Errorf("error marshalling frame to json")
return err return err
} }
logger.Debugf("sending %s", string(resJsonBytes)) logger.Debugf("sending %s", string(resJSONBytes))
err = hdlCtx.peer.conn.WriteMessage(websocket.TextMessage, resJsonBytes) err = hdlCtx.peer.conn.WriteMessage(websocket.TextMessage, resJSONBytes)
if err != nil { if err != nil {
logger.Errorf("error writing rframe") logger.Errorf("error writing rframe")
@ -184,7 +175,7 @@ func (hdlCtx *HandlerContext) clientWriter(syncCtx context.Context) error {
} }
} }
func (hdlCtx *HandlerContext) clientReader(syncCtx context.Context) error { func (hdlCtx *handlerContext) clientReader(syncCtx context.Context) error {
for { for {
select { select {
case <-syncCtx.Done(): case <-syncCtx.Done():
@ -225,7 +216,7 @@ func (hdlCtx *HandlerContext) clientReader(syncCtx context.Context) error {
} }
} }
func (hdlCtx *HandlerContext) sendRequest(req common.Request) error { func (hdlCtx *handlerContext) sendRequest(req common.Request) error {
rf, err := common.RequestFrameFrom(req) rf, err := common.RequestFrameFrom(req)
if err != nil { if err != nil {
@ -236,7 +227,7 @@ func (hdlCtx *HandlerContext) sendRequest(req common.Request) error {
return nil return nil
} }
func (hdlCtx *HandlerContext) getResponseFrame() common.RFrame { func (hdlCtx *handlerContext) getResponseFrame() common.RFrame {
return <-hdlCtx.resFromClient return <-hdlCtx.resFromClient
} }
@ -254,29 +245,28 @@ func init() {
} }
} }
type Matcher[T any] func(*T) bool func (ctx *serverContext) removePeer(peerToDelete *peer) {
func (ctx *Context) removePeer(peer *Peer) {
ctx.handlerContextsLock.Lock() ctx.handlerContextsLock.Lock()
ctx.peersListLock.Lock() ctx.peersListLock.Lock()
ctx.initiationsLock.Lock() ctx.initiationsLock.Lock()
ctx.handlerContexts = slices.DeleteFunc[[]*HandlerContext, *HandlerContext]( ctx.handlerContexts = slices.DeleteFunc(
ctx.handlerContexts, ctx.handlerContexts,
func(h *HandlerContext) bool { func(h *handlerContext) bool {
return h.peer.id == peer.id return h.peer.id == peerToDelete.id
}) })
ctx.peersList = slices.DeleteFunc[[]*Peer, *Peer]( ctx.peersList = slices.DeleteFunc(
ctx.peersList, ctx.peersList,
func(p *Peer) bool { func(p *peer) bool {
return p.id == peer.id return p.id == peerToDelete.id
}) })
ctx.initiations = slices.DeleteFunc[[]*common.Initiation, *common.Initiation]( ctx.initiations = slices.DeleteFunc(
ctx.initiations, ctx.initiations,
func(i *common.Initiation) bool { func(i *common.Initiation) bool {
return peer.hasAccount && (peer.account.nickname == i.AbANick || peer.account.nickname == i.AbBNick) return peerToDelete.hasAccount && (peerToDelete.account.nickname == i.AbANick ||
peerToDelete.account.nickname == i.AbBNick)
}) })
// TODO: Inform the other side about peer leaving // TODO: Inform the other side about peer leaving
@ -286,12 +276,12 @@ func (ctx *Context) removePeer(peer *Peer) {
ctx.initiationsLock.Unlock() ctx.initiationsLock.Unlock()
} }
func handleDisconnection(handlerCtx *HandlerContext) { func handleDisconnection(handlerCtx *handlerContext) {
handlerCtx.removePeer(handlerCtx.peer) handlerCtx.removePeer(handlerCtx.peer)
logger.Infof("%s disconnected", handlerCtx.peer.conn.RemoteAddr()) logger.Infof("%s disconnected", handlerCtx.peer.conn.RemoteAddr())
} }
func (hdlCtx *HandlerContext) handleEcho(reqFrame *common.RFrame) (res common.Response, err error) { func (hdlCtx *handlerContext) handleEcho(reqFrame *common.RFrame) (res common.Response, err error) {
echoReq, err := common.RequestFromFrame[common.EchoRequest](*reqFrame) echoReq, err := common.RequestFromFrame[common.EchoRequest](*reqFrame)
if err != nil { if err != nil {
@ -303,7 +293,7 @@ func (hdlCtx *HandlerContext) handleEcho(reqFrame *common.RFrame) (res common.Re
return echoRes, nil return echoRes, nil
} }
func (hdlCtx *HandlerContext) handleListPeers(reqFrame *common.RFrame) (res common.Response, err error) { func (hdlCtx *handlerContext) handleListPeers(reqFrame *common.RFrame) (res common.Response, err error) {
// Currently list peers request is empty, so we can ignore it - we won't use it // Currently list peers request is empty, so we can ignore it - we won't use it
_, err = common.RequestFromFrame[common.ListPeersRequest](*reqFrame) _, err = common.RequestFromFrame[common.ListPeersRequest](*reqFrame)
@ -313,7 +303,7 @@ func (hdlCtx *HandlerContext) handleListPeers(reqFrame *common.RFrame) (res comm
} }
hdlCtx.peersListLock.RLock() hdlCtx.peersListLock.RLock()
peersFreeze := make([]*Peer, len(hdlCtx.peersList)) peersFreeze := make([]*peer, len(hdlCtx.peersList))
copy(peersFreeze, hdlCtx.peersList) copy(peersFreeze, hdlCtx.peersList)
hdlCtx.peersListLock.RUnlock() hdlCtx.peersListLock.RUnlock()
listPeersRes := common.ListPeersResponse{PeersInfo: make([]common.PeerInfo, 0)} listPeersRes := common.ListPeersResponse{PeersInfo: make([]common.PeerInfo, 0)}
@ -333,7 +323,7 @@ func (hdlCtx *HandlerContext) handleListPeers(reqFrame *common.RFrame) (res comm
return listPeersRes, nil return listPeersRes, nil
} }
func (hdlCtx *HandlerContext) handleAuth(reqFrame *common.RFrame) (res common.Response, err error) { func (hdlCtx *handlerContext) handleAuth(reqFrame *common.RFrame) (res common.Response, err error) {
authReq, err := common.RequestFromFrame[common.AuthRequest](*reqFrame) authReq, err := common.RequestFromFrame[common.AuthRequest](*reqFrame)
if err != nil { if err != nil {
@ -343,17 +333,17 @@ func (hdlCtx *HandlerContext) handleAuth(reqFrame *common.RFrame) (res common.Re
// Check if account already exists // Check if account already exists
hdlCtx.accountsLock.RLock() hdlCtx.accountsLock.RLock()
account, ok := hdlCtx.accounts[authReq.Nickname] acc, ok := hdlCtx.accounts[authReq.Nickname]
hdlCtx.accountsLock.RUnlock() hdlCtx.accountsLock.RUnlock()
var authRes *common.AuthResponse var authRes *common.AuthResponse
if ok { if ok {
// Check if password matches // Check if password matches
if bcrypt.CompareHashAndPassword(account.passHash, []byte(authReq.Password)) == nil { if bcrypt.CompareHashAndPassword(acc.passHash, []byte(authReq.Password)) == nil {
authRes = &common.AuthResponse{IsSuccess: true} authRes = &common.AuthResponse{IsSuccess: true}
hdlCtx.peersListLock.Lock() hdlCtx.peersListLock.Lock()
hdlCtx.peer.hasAccount = true hdlCtx.peer.hasAccount = true
hdlCtx.peer.account = account hdlCtx.peer.account = acc
hdlCtx.peersListLock.Unlock() hdlCtx.peersListLock.Unlock()
} else { } else {
authRes = &common.AuthResponse{IsSuccess: false} authRes = &common.AuthResponse{IsSuccess: false}
@ -365,7 +355,7 @@ func (hdlCtx *HandlerContext) handleAuth(reqFrame *common.RFrame) (res common.Re
if err != nil { if err != nil {
authRes = &common.AuthResponse{IsSuccess: false} authRes = &common.AuthResponse{IsSuccess: false}
} else { } else {
newAcc := Account{authReq.Nickname, passHash} newAcc := account{authReq.Nickname, passHash}
hdlCtx.accountsLock.Lock() hdlCtx.accountsLock.Lock()
hdlCtx.accounts[newAcc.nickname] = &newAcc hdlCtx.accounts[newAcc.nickname] = &newAcc
hdlCtx.accountsLock.Unlock() hdlCtx.accountsLock.Unlock()
@ -379,7 +369,7 @@ func (hdlCtx *HandlerContext) handleAuth(reqFrame *common.RFrame) (res common.Re
return authRes, nil return authRes, nil
} }
func (hdlCtx *HandlerContext) handleChatStartA(reqFrame *common.RFrame) (res common.Response, err error) { func (hdlCtx *handlerContext) handleChatStartA(reqFrame *common.RFrame) (res common.Response, err error) {
startChatAReq, err := common.RequestFromFrame[common.StartChatARequest](*reqFrame) startChatAReq, err := common.RequestFromFrame[common.StartChatARequest](*reqFrame)
if err != nil { if err != nil {
@ -395,7 +385,7 @@ func (hdlCtx *HandlerContext) handleChatStartA(reqFrame *common.RFrame) (res com
// initation started // initation started
hdlCtx.initiationsLock.Lock() hdlCtx.initiationsLock.Lock()
hdlCtx.initiations = append(hdlCtx.initiations, NewInitiation(hdlCtx.peer.account.nickname, startChatAReq.Nickname)) hdlCtx.initiations = append(hdlCtx.initiations, newInitiation(hdlCtx.peer.account.nickname, startChatAReq.Nickname))
hdlCtx.initiationsLock.Unlock() hdlCtx.initiationsLock.Unlock()
chatStartB := common.StartChatBRequest{ chatStartB := common.StartChatBRequest{
@ -421,7 +411,7 @@ func (hdlCtx *HandlerContext) handleChatStartA(reqFrame *common.RFrame) (res com
return nil, nil return nil, nil
} }
func (hdlCtx *HandlerContext) handleChatStartC(reqFrame *common.RFrame) (res common.Response, err error) { func (hdlCtx *handlerContext) handleChatStartC(reqFrame *common.RFrame) (res common.Response, err error) {
hdlCtx.initiationsLock.Lock() hdlCtx.initiationsLock.Lock()
startChatCReq, err := common.RequestFromFrame[common.StartChatCRequest](*reqFrame) startChatCReq, err := common.RequestFromFrame[common.StartChatCRequest](*reqFrame)
@ -507,7 +497,7 @@ func generatePunchCode() (string, error) {
return string(codeBytes), nil return string(codeBytes), nil
} }
func (ctx *Context) printDebugInfo() { func (ctx *serverContext) printDebugInfo() {
ctx.peersListLock.RLock() ctx.peersListLock.RLock()
logger.Debug("================================ server state") logger.Debug("================================ server state")
logger.Debug("displaying all connections:") logger.Debug("displaying all connections:")
@ -531,7 +521,7 @@ func (ctx *Context) printDebugInfo() {
ctx.peersListLock.RUnlock() ctx.peersListLock.RUnlock()
} }
func (ctx *Context) addPeer(peer *Peer) { func (ctx *serverContext) addPeer(peer *peer) {
ctx.idCounterLock.Lock() ctx.idCounterLock.Lock()
ctx.idCounter++ ctx.idCounter++
peer.id = ctx.idCounter peer.id = ctx.idCounter
@ -541,7 +531,7 @@ func (ctx *Context) addPeer(peer *Peer) {
ctx.peersListLock.Unlock() ctx.peersListLock.Unlock()
} }
func testEcho(hdlCtx *HandlerContext) { func testEcho(hdlCtx *handlerContext) {
logger.Debug("sending echo request...") logger.Debug("sending echo request...")
_ = hdlCtx.sendRequest(common.EchoRequest{EchoByte: 123}) _ = hdlCtx.sendRequest(common.EchoRequest{EchoByte: 123})
logger.Debug("sent") logger.Debug("sent")
@ -557,7 +547,7 @@ func testEcho(hdlCtx *HandlerContext) {
logger.Debug("test echo done", "byteSent", 123, "byteReceived", echoRes.EchoByte) logger.Debug("test echo done", "byteSent", 123, "byteReceived", echoRes.EchoByte)
} }
func (ctx *Context) wsapiHandler(w http.ResponseWriter, r *http.Request) { func (ctx *serverContext) wsapiHandler(w http.ResponseWriter, r *http.Request) {
upgrader := websocket.Upgrader{} upgrader := websocket.Upgrader{}
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
@ -567,9 +557,9 @@ func (ctx *Context) wsapiHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
peer := NewPeer(conn) peer := newPeer(conn)
ctx.addPeer(peer) ctx.addPeer(peer)
handlerCtx := NewHandlerContext(peer, ctx) handlerCtx := newHandlerContext(peer, ctx)
ctx.handlerContextsLock.Lock() ctx.handlerContextsLock.Lock()
ctx.handlerContexts = append(ctx.handlerContexts, handlerCtx) ctx.handlerContexts = append(ctx.handlerContexts, handlerCtx)
ctx.handlerContextsLock.Unlock() ctx.handlerContextsLock.Unlock()
@ -616,7 +606,7 @@ func (ctx *Context) wsapiHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
func (srvCtx *Context) handleUDP(data []byte, addr net.Addr) { func (ctx *serverContext) handleUDP(data []byte, addr net.Addr) {
var punchReq common.PunchRequest var punchReq common.PunchRequest
err := json.Unmarshal(data, &punchReq) err := json.Unmarshal(data, &punchReq)
@ -627,10 +617,10 @@ func (srvCtx *Context) handleUDP(data []byte, addr net.Addr) {
logger.Debugf("got punch request %+v", punchReq) logger.Debugf("got punch request %+v", punchReq)
srvCtx.initiationsLock.Lock() ctx.initiationsLock.Lock()
defer srvCtx.initiationsLock.Unlock() defer ctx.initiationsLock.Unlock()
idx := slices.IndexFunc(srvCtx.initiations, func(i *common.Initiation) bool { idx := slices.IndexFunc(ctx.initiations, func(i *common.Initiation) bool {
return i.AbAPunchCode == punchReq.PunchCode || return i.AbAPunchCode == punchReq.PunchCode ||
i.AbBPunchCode == punchReq.PunchCode i.AbBPunchCode == punchReq.PunchCode
}) })
@ -640,7 +630,7 @@ func (srvCtx *Context) handleUDP(data []byte, addr net.Addr) {
return return
} }
matchedInitation := srvCtx.initiations[idx] matchedInitation := ctx.initiations[idx]
logger.Debugf("matched initiation %+v", matchedInitation) logger.Debugf("matched initiation %+v", matchedInitation)
if matchedInitation.AbAPunchCode == punchReq.PunchCode { if matchedInitation.AbAPunchCode == punchReq.PunchCode {
@ -657,10 +647,10 @@ func (srvCtx *Context) handleUDP(data []byte, addr net.Addr) {
logger.Debugf("finished completing initiation %+v", matchedInitation) logger.Debugf("finished completing initiation %+v", matchedInitation)
logger.Debug("now sending peers their addresses") logger.Debug("now sending peers their addresses")
srvCtx.peersListLock.Lock() ctx.peersListLock.Lock()
defer srvCtx.peersListLock.Unlock() defer ctx.peersListLock.Unlock()
abA, err := srvCtx.getCtxByNick(matchedInitation.AbANick) abA, err := ctx.getCtxByNick(matchedInitation.AbANick)
if err != nil { if err != nil {
logger.Debug("could not finish punching, abA not found", logger.Debug("could not finish punching, abA not found",
@ -668,7 +658,7 @@ func (srvCtx *Context) handleUDP(data []byte, addr net.Addr) {
return return
} }
abB, err := srvCtx.getCtxByNick(matchedInitation.AbBNick) abB, err := ctx.getCtxByNick(matchedInitation.AbBNick)
if err != nil { if err != nil {
logger.Debug("could not finish punching, abB not found", logger.Debug("could not finish punching, abB not found",
@ -676,7 +666,7 @@ func (srvCtx *Context) handleUDP(data []byte, addr net.Addr) {
return return
} }
srvCtx.initiations = slices.DeleteFunc(srvCtx.initiations, ctx.initiations = slices.DeleteFunc(ctx.initiations,
func(i *common.Initiation) bool { func(i *common.Initiation) bool {
return i.AbAPunchCode == matchedInitation.AbAPunchCode || return i.AbAPunchCode == matchedInitation.AbAPunchCode ||
i.AbBPunchCode == matchedInitation.AbAPunchCode i.AbBPunchCode == matchedInitation.AbAPunchCode
@ -705,8 +695,9 @@ func (srvCtx *Context) handleUDP(data []byte, addr net.Addr) {
} }
} }
// RunServer starts archat server according to settings passed in common.ServerSettings as the first argument
func RunServer(settings common.ServerSettings) { func RunServer(settings common.ServerSettings) {
srvCtx := NewContext() srvCtx := newContext()
go func() { go func() {
for { for {