Refactor server.go

This commit is contained in:
Maciej Krzyżanowski 2024-06-21 13:41:10 +02:00
parent 2468c6d077
commit f7dc487546

View File

@ -1,3 +1,4 @@
// Package server includes functions and structures which allow to run arhcat connection initiating server
package server
import (
@ -21,70 +22,60 @@ import (
"krzyzanowski.dev/archat/common"
)
type Peer struct {
type peer struct {
id int
conn *websocket.Conn
hasAccount bool
account *Account
account *account
}
func NewPeer(conn *websocket.Conn) *Peer {
return &Peer{-1, conn, false, nil}
func newPeer(conn *websocket.Conn) *peer {
return &peer{-1, conn, false, nil}
}
func (p *Peer) NicknameOrEmpty() string {
func (p *peer) NicknameOrEmpty() string {
if p.hasAccount {
return p.account.nickname
} else {
return ""
}
return ""
}
type Account struct {
type account struct {
nickname string
passHash []byte
}
func NewInitiation(abA string, abB string) *common.Initiation {
func newInitiation(abA string, abB string) *common.Initiation {
return &common.Initiation{AbANick: abA, AbBNick: abB, Stage: common.InitiationStageA}
}
type Context struct {
type serverContext struct {
idCounter int
idCounterLock sync.RWMutex
peersList []*Peer
peersList []*peer
peersListLock sync.RWMutex
accounts map[string]*Account
accounts map[string]*account
accountsLock sync.RWMutex
initiations []*common.Initiation
initiationsLock sync.RWMutex
handlerContexts []*HandlerContext
handlerContexts []*handlerContext
handlerContextsLock sync.RWMutex
}
func NewContext() *Context {
return &Context{
peersList: make([]*Peer, 0),
accounts: make(map[string]*Account),
func newContext() *serverContext {
return &serverContext{
peersList: make([]*peer, 0),
accounts: make(map[string]*account),
initiations: make([]*common.Initiation, 0),
}
}
// Remember to lock before calling
func (ctx *Context) getPeerByNick(nick string) (*Peer, error) {
for _, peer := range ctx.peersList {
if peer.hasAccount && peer.account.nickname == nick {
return peer, nil
}
}
return nil, errors.New("peer not found")
}
// Remember to lock before calling
func (ctx *Context) getCtxByNick(nick string) (*HandlerContext, error) {
idx := slices.IndexFunc[[]*HandlerContext, *HandlerContext](
func (ctx *serverContext) getCtxByNick(nick string) (*handlerContext, error) {
idx := slices.IndexFunc(
ctx.handlerContexts,
func(handlerContext *HandlerContext) bool {
func(handlerContext *handlerContext) bool {
return handlerContext.peer.hasAccount && handlerContext.peer.account.nickname == nick
})
@ -95,16 +86,16 @@ func (ctx *Context) getCtxByNick(nick string) (*HandlerContext, error) {
return nil, errors.New("not found")
}
type HandlerContext struct {
peer *Peer
*Context
type handlerContext struct {
peer *peer
*serverContext
resFromClient chan common.RFrame
reqFromClient chan common.RFrame
rToClient chan common.RFrame
}
func NewHandlerContext(peer *Peer, srvCtx *Context) *HandlerContext {
return &HandlerContext{
func newHandlerContext(peer *peer, srvCtx *serverContext) *handlerContext {
return &handlerContext{
peer,
srvCtx,
make(chan common.RFrame),
@ -113,7 +104,7 @@ func NewHandlerContext(peer *Peer, srvCtx *Context) *HandlerContext {
}
}
func (hdlCtx *HandlerContext) clientHandler(syncCtx context.Context) error {
func (hdlCtx *handlerContext) clientHandler(syncCtx context.Context) error {
handleNext:
for {
select {
@ -160,21 +151,21 @@ handleNext:
}
}
func (hdlCtx *HandlerContext) clientWriter(syncCtx context.Context) error {
func (hdlCtx *handlerContext) clientWriter(syncCtx context.Context) error {
for {
select {
case <-syncCtx.Done():
return nil
case rFrame := <-hdlCtx.rToClient:
resJsonBytes, err := json.Marshal(rFrame)
resJSONBytes, err := json.Marshal(rFrame)
if err != nil {
logger.Errorf("error marshalling frame to json")
return err
}
logger.Debugf("sending %s", string(resJsonBytes))
err = hdlCtx.peer.conn.WriteMessage(websocket.TextMessage, resJsonBytes)
logger.Debugf("sending %s", string(resJSONBytes))
err = hdlCtx.peer.conn.WriteMessage(websocket.TextMessage, resJSONBytes)
if err != nil {
logger.Errorf("error writing rframe")
@ -184,7 +175,7 @@ func (hdlCtx *HandlerContext) clientWriter(syncCtx context.Context) error {
}
}
func (hdlCtx *HandlerContext) clientReader(syncCtx context.Context) error {
func (hdlCtx *handlerContext) clientReader(syncCtx context.Context) error {
for {
select {
case <-syncCtx.Done():
@ -225,7 +216,7 @@ func (hdlCtx *HandlerContext) clientReader(syncCtx context.Context) error {
}
}
func (hdlCtx *HandlerContext) sendRequest(req common.Request) error {
func (hdlCtx *handlerContext) sendRequest(req common.Request) error {
rf, err := common.RequestFrameFrom(req)
if err != nil {
@ -236,7 +227,7 @@ func (hdlCtx *HandlerContext) sendRequest(req common.Request) error {
return nil
}
func (hdlCtx *HandlerContext) getResponseFrame() common.RFrame {
func (hdlCtx *handlerContext) getResponseFrame() common.RFrame {
return <-hdlCtx.resFromClient
}
@ -254,29 +245,28 @@ func init() {
}
}
type Matcher[T any] func(*T) bool
func (ctx *Context) removePeer(peer *Peer) {
func (ctx *serverContext) removePeer(peerToDelete *peer) {
ctx.handlerContextsLock.Lock()
ctx.peersListLock.Lock()
ctx.initiationsLock.Lock()
ctx.handlerContexts = slices.DeleteFunc[[]*HandlerContext, *HandlerContext](
ctx.handlerContexts = slices.DeleteFunc(
ctx.handlerContexts,
func(h *HandlerContext) bool {
return h.peer.id == peer.id
func(h *handlerContext) bool {
return h.peer.id == peerToDelete.id
})
ctx.peersList = slices.DeleteFunc[[]*Peer, *Peer](
ctx.peersList = slices.DeleteFunc(
ctx.peersList,
func(p *Peer) bool {
return p.id == peer.id
func(p *peer) bool {
return p.id == peerToDelete.id
})
ctx.initiations = slices.DeleteFunc[[]*common.Initiation, *common.Initiation](
ctx.initiations = slices.DeleteFunc(
ctx.initiations,
func(i *common.Initiation) bool {
return peer.hasAccount && (peer.account.nickname == i.AbANick || peer.account.nickname == i.AbBNick)
return peerToDelete.hasAccount && (peerToDelete.account.nickname == i.AbANick ||
peerToDelete.account.nickname == i.AbBNick)
})
// TODO: Inform the other side about peer leaving
@ -286,12 +276,12 @@ func (ctx *Context) removePeer(peer *Peer) {
ctx.initiationsLock.Unlock()
}
func handleDisconnection(handlerCtx *HandlerContext) {
func handleDisconnection(handlerCtx *handlerContext) {
handlerCtx.removePeer(handlerCtx.peer)
logger.Infof("%s disconnected", handlerCtx.peer.conn.RemoteAddr())
}
func (hdlCtx *HandlerContext) handleEcho(reqFrame *common.RFrame) (res common.Response, err error) {
func (hdlCtx *handlerContext) handleEcho(reqFrame *common.RFrame) (res common.Response, err error) {
echoReq, err := common.RequestFromFrame[common.EchoRequest](*reqFrame)
if err != nil {
@ -303,7 +293,7 @@ func (hdlCtx *HandlerContext) handleEcho(reqFrame *common.RFrame) (res common.Re
return echoRes, nil
}
func (hdlCtx *HandlerContext) handleListPeers(reqFrame *common.RFrame) (res common.Response, err error) {
func (hdlCtx *handlerContext) handleListPeers(reqFrame *common.RFrame) (res common.Response, err error) {
// Currently list peers request is empty, so we can ignore it - we won't use it
_, err = common.RequestFromFrame[common.ListPeersRequest](*reqFrame)
@ -313,7 +303,7 @@ func (hdlCtx *HandlerContext) handleListPeers(reqFrame *common.RFrame) (res comm
}
hdlCtx.peersListLock.RLock()
peersFreeze := make([]*Peer, len(hdlCtx.peersList))
peersFreeze := make([]*peer, len(hdlCtx.peersList))
copy(peersFreeze, hdlCtx.peersList)
hdlCtx.peersListLock.RUnlock()
listPeersRes := common.ListPeersResponse{PeersInfo: make([]common.PeerInfo, 0)}
@ -333,7 +323,7 @@ func (hdlCtx *HandlerContext) handleListPeers(reqFrame *common.RFrame) (res comm
return listPeersRes, nil
}
func (hdlCtx *HandlerContext) handleAuth(reqFrame *common.RFrame) (res common.Response, err error) {
func (hdlCtx *handlerContext) handleAuth(reqFrame *common.RFrame) (res common.Response, err error) {
authReq, err := common.RequestFromFrame[common.AuthRequest](*reqFrame)
if err != nil {
@ -343,17 +333,17 @@ func (hdlCtx *HandlerContext) handleAuth(reqFrame *common.RFrame) (res common.Re
// Check if account already exists
hdlCtx.accountsLock.RLock()
account, ok := hdlCtx.accounts[authReq.Nickname]
acc, ok := hdlCtx.accounts[authReq.Nickname]
hdlCtx.accountsLock.RUnlock()
var authRes *common.AuthResponse
if ok {
// Check if password matches
if bcrypt.CompareHashAndPassword(account.passHash, []byte(authReq.Password)) == nil {
if bcrypt.CompareHashAndPassword(acc.passHash, []byte(authReq.Password)) == nil {
authRes = &common.AuthResponse{IsSuccess: true}
hdlCtx.peersListLock.Lock()
hdlCtx.peer.hasAccount = true
hdlCtx.peer.account = account
hdlCtx.peer.account = acc
hdlCtx.peersListLock.Unlock()
} else {
authRes = &common.AuthResponse{IsSuccess: false}
@ -365,7 +355,7 @@ func (hdlCtx *HandlerContext) handleAuth(reqFrame *common.RFrame) (res common.Re
if err != nil {
authRes = &common.AuthResponse{IsSuccess: false}
} else {
newAcc := Account{authReq.Nickname, passHash}
newAcc := account{authReq.Nickname, passHash}
hdlCtx.accountsLock.Lock()
hdlCtx.accounts[newAcc.nickname] = &newAcc
hdlCtx.accountsLock.Unlock()
@ -379,7 +369,7 @@ func (hdlCtx *HandlerContext) handleAuth(reqFrame *common.RFrame) (res common.Re
return authRes, nil
}
func (hdlCtx *HandlerContext) handleChatStartA(reqFrame *common.RFrame) (res common.Response, err error) {
func (hdlCtx *handlerContext) handleChatStartA(reqFrame *common.RFrame) (res common.Response, err error) {
startChatAReq, err := common.RequestFromFrame[common.StartChatARequest](*reqFrame)
if err != nil {
@ -395,7 +385,7 @@ func (hdlCtx *HandlerContext) handleChatStartA(reqFrame *common.RFrame) (res com
// initation started
hdlCtx.initiationsLock.Lock()
hdlCtx.initiations = append(hdlCtx.initiations, NewInitiation(hdlCtx.peer.account.nickname, startChatAReq.Nickname))
hdlCtx.initiations = append(hdlCtx.initiations, newInitiation(hdlCtx.peer.account.nickname, startChatAReq.Nickname))
hdlCtx.initiationsLock.Unlock()
chatStartB := common.StartChatBRequest{
@ -421,7 +411,7 @@ func (hdlCtx *HandlerContext) handleChatStartA(reqFrame *common.RFrame) (res com
return nil, nil
}
func (hdlCtx *HandlerContext) handleChatStartC(reqFrame *common.RFrame) (res common.Response, err error) {
func (hdlCtx *handlerContext) handleChatStartC(reqFrame *common.RFrame) (res common.Response, err error) {
hdlCtx.initiationsLock.Lock()
startChatCReq, err := common.RequestFromFrame[common.StartChatCRequest](*reqFrame)
@ -507,7 +497,7 @@ func generatePunchCode() (string, error) {
return string(codeBytes), nil
}
func (ctx *Context) printDebugInfo() {
func (ctx *serverContext) printDebugInfo() {
ctx.peersListLock.RLock()
logger.Debug("================================ server state")
logger.Debug("displaying all connections:")
@ -531,7 +521,7 @@ func (ctx *Context) printDebugInfo() {
ctx.peersListLock.RUnlock()
}
func (ctx *Context) addPeer(peer *Peer) {
func (ctx *serverContext) addPeer(peer *peer) {
ctx.idCounterLock.Lock()
ctx.idCounter++
peer.id = ctx.idCounter
@ -541,7 +531,7 @@ func (ctx *Context) addPeer(peer *Peer) {
ctx.peersListLock.Unlock()
}
func testEcho(hdlCtx *HandlerContext) {
func testEcho(hdlCtx *handlerContext) {
logger.Debug("sending echo request...")
_ = hdlCtx.sendRequest(common.EchoRequest{EchoByte: 123})
logger.Debug("sent")
@ -557,7 +547,7 @@ func testEcho(hdlCtx *HandlerContext) {
logger.Debug("test echo done", "byteSent", 123, "byteReceived", echoRes.EchoByte)
}
func (ctx *Context) wsapiHandler(w http.ResponseWriter, r *http.Request) {
func (ctx *serverContext) wsapiHandler(w http.ResponseWriter, r *http.Request) {
upgrader := websocket.Upgrader{}
conn, err := upgrader.Upgrade(w, r, nil)
@ -567,9 +557,9 @@ func (ctx *Context) wsapiHandler(w http.ResponseWriter, r *http.Request) {
return
}
peer := NewPeer(conn)
peer := newPeer(conn)
ctx.addPeer(peer)
handlerCtx := NewHandlerContext(peer, ctx)
handlerCtx := newHandlerContext(peer, ctx)
ctx.handlerContextsLock.Lock()
ctx.handlerContexts = append(ctx.handlerContexts, handlerCtx)
ctx.handlerContextsLock.Unlock()
@ -616,7 +606,7 @@ func (ctx *Context) wsapiHandler(w http.ResponseWriter, r *http.Request) {
}
}
func (srvCtx *Context) handleUDP(data []byte, addr net.Addr) {
func (ctx *serverContext) handleUDP(data []byte, addr net.Addr) {
var punchReq common.PunchRequest
err := json.Unmarshal(data, &punchReq)
@ -627,10 +617,10 @@ func (srvCtx *Context) handleUDP(data []byte, addr net.Addr) {
logger.Debugf("got punch request %+v", punchReq)
srvCtx.initiationsLock.Lock()
defer srvCtx.initiationsLock.Unlock()
ctx.initiationsLock.Lock()
defer ctx.initiationsLock.Unlock()
idx := slices.IndexFunc(srvCtx.initiations, func(i *common.Initiation) bool {
idx := slices.IndexFunc(ctx.initiations, func(i *common.Initiation) bool {
return i.AbAPunchCode == punchReq.PunchCode ||
i.AbBPunchCode == punchReq.PunchCode
})
@ -640,7 +630,7 @@ func (srvCtx *Context) handleUDP(data []byte, addr net.Addr) {
return
}
matchedInitation := srvCtx.initiations[idx]
matchedInitation := ctx.initiations[idx]
logger.Debugf("matched initiation %+v", matchedInitation)
if matchedInitation.AbAPunchCode == punchReq.PunchCode {
@ -657,10 +647,10 @@ func (srvCtx *Context) handleUDP(data []byte, addr net.Addr) {
logger.Debugf("finished completing initiation %+v", matchedInitation)
logger.Debug("now sending peers their addresses")
srvCtx.peersListLock.Lock()
defer srvCtx.peersListLock.Unlock()
ctx.peersListLock.Lock()
defer ctx.peersListLock.Unlock()
abA, err := srvCtx.getCtxByNick(matchedInitation.AbANick)
abA, err := ctx.getCtxByNick(matchedInitation.AbANick)
if err != nil {
logger.Debug("could not finish punching, abA not found",
@ -668,7 +658,7 @@ func (srvCtx *Context) handleUDP(data []byte, addr net.Addr) {
return
}
abB, err := srvCtx.getCtxByNick(matchedInitation.AbBNick)
abB, err := ctx.getCtxByNick(matchedInitation.AbBNick)
if err != nil {
logger.Debug("could not finish punching, abB not found",
@ -676,7 +666,7 @@ func (srvCtx *Context) handleUDP(data []byte, addr net.Addr) {
return
}
srvCtx.initiations = slices.DeleteFunc(srvCtx.initiations,
ctx.initiations = slices.DeleteFunc(ctx.initiations,
func(i *common.Initiation) bool {
return i.AbAPunchCode == matchedInitation.AbAPunchCode ||
i.AbBPunchCode == matchedInitation.AbAPunchCode
@ -705,8 +695,9 @@ func (srvCtx *Context) handleUDP(data []byte, addr net.Addr) {
}
}
// RunServer starts archat server according to settings passed in common.ServerSettings as the first argument
func RunServer(settings common.ServerSettings) {
srvCtx := NewContext()
srvCtx := newContext()
go func() {
for {