mirror of
https://github.com/originalmk/archat-server.git
synced 2024-11-20 09:38:50 +00:00
635 lines
15 KiB
Go
635 lines
15 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/json"
|
|
"errors"
|
|
"net/http"
|
|
"os"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/sync/errgroup"
|
|
|
|
"github.com/charmbracelet/log"
|
|
"github.com/gorilla/websocket"
|
|
"golang.org/x/crypto/bcrypt"
|
|
"krzyzanowski.dev/archat/common"
|
|
)
|
|
|
|
type Peer struct {
|
|
id int
|
|
conn *websocket.Conn
|
|
hasAccount bool
|
|
account *Account
|
|
}
|
|
|
|
func NewPeer(conn *websocket.Conn) *Peer {
|
|
return &Peer{-1, conn, false, nil}
|
|
}
|
|
|
|
func (p *Peer) NicknameOrEmpty() string {
|
|
if p.hasAccount {
|
|
return p.account.nickname
|
|
} else {
|
|
return ""
|
|
}
|
|
}
|
|
|
|
// Account is a registered user known to the server.
type Account struct {
	// nickname identifies the account; it is also the key in Context.accounts.
	nickname string
	// passHash is the bcrypt hash of the account password.
	passHash []byte
}
|
|
|
|
func NewInitiation(abA string, abB string) *common.Initiation {
|
|
return &common.Initiation{AbANick: abA, AbBNick: abB, Stage: common.InitiationStageA}
|
|
}
|
|
|
|
type Context struct {
|
|
idCounter int
|
|
idCounterLock sync.RWMutex
|
|
peersList []*Peer
|
|
peersListLock sync.RWMutex
|
|
accounts map[string]*Account
|
|
accountsLock sync.RWMutex
|
|
initiations []*common.Initiation
|
|
initiationsLock sync.RWMutex
|
|
handlerContexts []*HandlerContext
|
|
handlerContextsLock sync.RWMutex
|
|
}
|
|
|
|
func NewContext() *Context {
|
|
return &Context{
|
|
peersList: make([]*Peer, 0),
|
|
accounts: make(map[string]*Account),
|
|
initiations: make([]*common.Initiation, 0),
|
|
}
|
|
}
|
|
|
|
// Remember to lock before calling
|
|
func (ctx *Context) getPeerByNick(nick string) (*Peer, error) {
|
|
for _, peer := range ctx.peersList {
|
|
if peer.hasAccount && peer.account.nickname == nick {
|
|
return peer, nil
|
|
}
|
|
}
|
|
return nil, errors.New("peer not found")
|
|
}
|
|
|
|
// Remember to lock before calling
|
|
func (ctx *Context) getCtxByNick(nick string) (*HandlerContext, error) {
|
|
idx := slices.IndexFunc[[]*HandlerContext, *HandlerContext](
|
|
ctx.handlerContexts,
|
|
func(handlerContext *HandlerContext) bool {
|
|
return handlerContext.peer.hasAccount && handlerContext.peer.account.nickname == nick
|
|
})
|
|
|
|
if idx != -1 {
|
|
return ctx.handlerContexts[idx], nil
|
|
}
|
|
|
|
return nil, errors.New("not found")
|
|
}
|
|
|
|
type HandlerContext struct {
|
|
peer *Peer
|
|
*Context
|
|
resFromClient chan common.RFrame
|
|
reqFromClient chan common.RFrame
|
|
rToClient chan common.RFrame
|
|
}
|
|
|
|
func NewHandlerContext(peer *Peer, srvCtx *Context) *HandlerContext {
|
|
return &HandlerContext{
|
|
peer,
|
|
srvCtx,
|
|
make(chan common.RFrame),
|
|
make(chan common.RFrame),
|
|
make(chan common.RFrame),
|
|
}
|
|
}
|
|
|
|
func (hdlCtx *HandlerContext) clientHandler(syncCtx context.Context) error {
|
|
handleNext:
|
|
for {
|
|
select {
|
|
case <-syncCtx.Done():
|
|
return nil
|
|
case reqFrame := <-hdlCtx.reqFromClient:
|
|
var res common.Response
|
|
var err error
|
|
|
|
if reqFrame.ID == common.AuthReqID {
|
|
res, err = hdlCtx.handleAuth(&reqFrame)
|
|
} else if reqFrame.ID == common.ListPeersReqID {
|
|
res, err = hdlCtx.handleListPeers(&reqFrame)
|
|
} else if reqFrame.ID == common.EchoReqID {
|
|
res, err = hdlCtx.handleEcho(&reqFrame)
|
|
} else if reqFrame.ID == common.StartChatAReqID {
|
|
res, err = hdlCtx.handleChatStartA(&reqFrame)
|
|
} else if reqFrame.ID == common.StartChatCReqID {
|
|
res, err = hdlCtx.handleChatStartC(&reqFrame)
|
|
} else {
|
|
logger.Warnf("can't handle request of ID=%d", reqFrame.ID)
|
|
continue
|
|
}
|
|
|
|
if err != nil {
|
|
logger.Errorf("could not handle request ID=%d", reqFrame.ID)
|
|
return err
|
|
}
|
|
|
|
if res == nil {
|
|
logger.Debugf("request without response ID=%d", reqFrame.ID)
|
|
continue handleNext
|
|
}
|
|
|
|
resFrame, err := common.ResponseFrameFrom(res)
|
|
|
|
if err != nil {
|
|
logger.Errorf("could not create frame from response")
|
|
return err
|
|
}
|
|
|
|
hdlCtx.rToClient <- resFrame
|
|
}
|
|
}
|
|
}
|
|
|
|
func (hdlCtx *HandlerContext) clientWriter(syncCtx context.Context) error {
|
|
for {
|
|
select {
|
|
case <-syncCtx.Done():
|
|
return nil
|
|
case rFrame := <-hdlCtx.rToClient:
|
|
resJsonBytes, err := json.Marshal(rFrame)
|
|
|
|
if err != nil {
|
|
logger.Errorf("error marshalling frame to json")
|
|
return err
|
|
}
|
|
|
|
logger.Debugf("sending %s", string(resJsonBytes))
|
|
err = hdlCtx.peer.conn.WriteMessage(websocket.TextMessage, resJsonBytes)
|
|
|
|
if err != nil {
|
|
logger.Errorf("error writing rframe")
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (hdlCtx *HandlerContext) clientReader(syncCtx context.Context) error {
|
|
for {
|
|
select {
|
|
case <-syncCtx.Done():
|
|
return nil
|
|
default:
|
|
messType, messBytes, err := hdlCtx.peer.conn.ReadMessage()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if messType != 1 {
|
|
err := hdlCtx.peer.conn.WriteMessage(websocket.CloseUnsupportedData, []byte("Only JSON text is supported"))
|
|
if err != nil {
|
|
logger.Debugf("[Server] error sending unsupported data close message")
|
|
return err
|
|
}
|
|
}
|
|
|
|
logger.Debugf("got message text: %s", strings.Trim(string(messBytes), "\n"))
|
|
var rFrame common.RFrame
|
|
err = json.Unmarshal(messBytes, &rFrame)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
logger.Debugf("unmarshalled request frame (ID=%d)", rFrame.ID)
|
|
|
|
if rFrame.IsResponse() {
|
|
logger.Debug("it is response frame", "id", rFrame.ID)
|
|
hdlCtx.resFromClient <- rFrame
|
|
} else {
|
|
logger.Debug("it is request frame", "id", rFrame.ID)
|
|
hdlCtx.reqFromClient <- rFrame
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (hdlCtx *HandlerContext) sendRequest(req common.Request) error {
|
|
rf, err := common.RequestFrameFrom(req)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
hdlCtx.rToClient <- rf
|
|
return nil
|
|
}
|
|
|
|
func (hdlCtx *HandlerContext) getResponseFrame() common.RFrame {
|
|
return <-hdlCtx.resFromClient
|
|
}
|
|
|
|
var logger = log.NewWithOptions(os.Stdout, log.Options{
|
|
ReportTimestamp: true,
|
|
TimeFormat: time.TimeOnly,
|
|
Prefix: "⚙️ Server",
|
|
})
|
|
|
|
func init() {
|
|
if common.IsProd {
|
|
logger.SetLevel(log.InfoLevel)
|
|
} else {
|
|
logger.SetLevel(log.DebugLevel)
|
|
}
|
|
}
|
|
|
|
// Matcher is a predicate over a pointer to T.
type Matcher[T any] func(*T) bool
|
|
|
|
func (ctx *Context) removePeer(peer *Peer) {
|
|
ctx.handlerContextsLock.Lock()
|
|
ctx.peersListLock.Lock()
|
|
ctx.initiationsLock.Lock()
|
|
|
|
ctx.handlerContexts = slices.DeleteFunc[[]*HandlerContext, *HandlerContext](
|
|
ctx.handlerContexts,
|
|
func(h *HandlerContext) bool {
|
|
return h.peer.id == peer.id
|
|
})
|
|
|
|
ctx.peersList = slices.DeleteFunc[[]*Peer, *Peer](
|
|
ctx.peersList,
|
|
func(p *Peer) bool {
|
|
return p.id == peer.id
|
|
})
|
|
|
|
ctx.initiations = slices.DeleteFunc[[]*common.Initiation, *common.Initiation](
|
|
ctx.initiations,
|
|
func(i *common.Initiation) bool {
|
|
return peer.hasAccount && (peer.account.nickname == i.AbANick || peer.account.nickname == i.AbBNick)
|
|
})
|
|
|
|
// TODO: Inform the other side about peer leaving
|
|
|
|
ctx.handlerContextsLock.Unlock()
|
|
ctx.peersListLock.Unlock()
|
|
ctx.initiationsLock.Unlock()
|
|
}
|
|
|
|
func handleDisconnection(handlerCtx *HandlerContext) {
|
|
handlerCtx.removePeer(handlerCtx.peer)
|
|
logger.Infof("%s disconnected", handlerCtx.peer.conn.RemoteAddr())
|
|
}
|
|
|
|
func (hdlCtx *HandlerContext) handleEcho(reqFrame *common.RFrame) (res common.Response, err error) {
|
|
echoReq, err := common.RequestFromFrame[common.EchoRequest](*reqFrame)
|
|
|
|
if err != nil {
|
|
logger.Error("could not read request from frame")
|
|
return nil, err
|
|
}
|
|
|
|
echoRes := common.EchoResponse(echoReq)
|
|
return echoRes, nil
|
|
}
|
|
|
|
func (hdlCtx *HandlerContext) handleListPeers(reqFrame *common.RFrame) (res common.Response, err error) {
|
|
// Currently list peers request is empty, so we can ignore it - we won't use it
|
|
_, err = common.RequestFromFrame[common.ListPeersRequest](*reqFrame)
|
|
|
|
if err != nil {
|
|
logger.Error("could not read request from frame")
|
|
return nil, err
|
|
}
|
|
|
|
hdlCtx.peersListLock.RLock()
|
|
peersFreeze := make([]*Peer, len(hdlCtx.peersList))
|
|
copy(peersFreeze, hdlCtx.peersList)
|
|
hdlCtx.peersListLock.RUnlock()
|
|
listPeersRes := common.ListPeersResponse{PeersInfo: make([]common.PeerInfo, 0)}
|
|
|
|
for _, peer := range peersFreeze {
|
|
listPeersRes.PeersInfo = append(
|
|
listPeersRes.PeersInfo,
|
|
common.PeerInfo{
|
|
ID: peer.id,
|
|
Addr: peer.conn.RemoteAddr().String(),
|
|
HasNickname: peer.hasAccount,
|
|
Nickname: peer.NicknameOrEmpty(),
|
|
},
|
|
)
|
|
}
|
|
|
|
return listPeersRes, nil
|
|
}
|
|
|
|
func (hdlCtx *HandlerContext) handleAuth(reqFrame *common.RFrame) (res common.Response, err error) {
|
|
authReq, err := common.RequestFromFrame[common.AuthRequest](*reqFrame)
|
|
|
|
if err != nil {
|
|
logger.Error("could not read request from frame")
|
|
return nil, err
|
|
}
|
|
|
|
// Check if account already exists
|
|
hdlCtx.accountsLock.RLock()
|
|
account, ok := hdlCtx.accounts[authReq.Nickname]
|
|
hdlCtx.accountsLock.RUnlock()
|
|
var authRes *common.AuthResponse
|
|
|
|
if ok {
|
|
// Check if password matches
|
|
if bcrypt.CompareHashAndPassword(account.passHash, []byte(authReq.Password)) == nil {
|
|
authRes = &common.AuthResponse{IsSuccess: true}
|
|
hdlCtx.peersListLock.Lock()
|
|
hdlCtx.peer.hasAccount = true
|
|
hdlCtx.peer.account = account
|
|
hdlCtx.peersListLock.Unlock()
|
|
} else {
|
|
authRes = &common.AuthResponse{IsSuccess: false}
|
|
}
|
|
} else {
|
|
authRes = &common.AuthResponse{IsSuccess: true}
|
|
passHash, err := bcrypt.GenerateFromPassword([]byte(authReq.Password), bcrypt.DefaultCost)
|
|
|
|
if err != nil {
|
|
authRes = &common.AuthResponse{IsSuccess: false}
|
|
} else {
|
|
newAcc := Account{authReq.Nickname, passHash}
|
|
hdlCtx.accountsLock.Lock()
|
|
hdlCtx.accounts[newAcc.nickname] = &newAcc
|
|
hdlCtx.accountsLock.Unlock()
|
|
hdlCtx.peersListLock.Lock()
|
|
hdlCtx.peer.hasAccount = true
|
|
hdlCtx.peer.account = &newAcc
|
|
hdlCtx.peersListLock.Unlock()
|
|
}
|
|
}
|
|
|
|
return authRes, nil
|
|
}
|
|
|
|
func (hdlCtx *HandlerContext) handleChatStartA(reqFrame *common.RFrame) (res common.Response, err error) {
|
|
startChatAReq, err := common.RequestFromFrame[common.StartChatARequest](*reqFrame)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
receiverPeerCtx, err := hdlCtx.getCtxByNick(startChatAReq.Nickname)
|
|
|
|
if err != nil {
|
|
logger.Debug("receiver peer not found")
|
|
return nil, nil
|
|
}
|
|
|
|
// initation started
|
|
hdlCtx.initiationsLock.Lock()
|
|
hdlCtx.initiations = append(hdlCtx.initiations, NewInitiation(hdlCtx.peer.account.nickname, startChatAReq.Nickname))
|
|
hdlCtx.initiationsLock.Unlock()
|
|
|
|
chatStartB := common.StartChatBRequest{
|
|
Nickname: hdlCtx.peer.account.nickname,
|
|
}
|
|
|
|
chatStartBReqF, err := common.RequestFrameFrom(chatStartB)
|
|
|
|
if err != nil {
|
|
logger.Debug("chat start B req frame creation failed")
|
|
return nil, err
|
|
}
|
|
|
|
receiverPeerCtx.rToClient <- chatStartBReqF
|
|
|
|
hdlCtx.initiationsLock.Lock()
|
|
idx := slices.IndexFunc(hdlCtx.initiations, func(i *common.Initiation) bool {
|
|
return i.AbANick == hdlCtx.peer.account.nickname && i.AbBNick == startChatAReq.Nickname
|
|
})
|
|
hdlCtx.initiations[idx].Stage = common.InitiationStageB
|
|
hdlCtx.initiationsLock.Unlock()
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
func (hdlCtx *HandlerContext) handleChatStartC(reqFrame *common.RFrame) (res common.Response, err error) {
|
|
hdlCtx.initiationsLock.Lock()
|
|
startChatCReq, err := common.RequestFromFrame[common.StartChatCRequest](*reqFrame)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
logger.Debugf("got chat start c for %s", startChatCReq.Nickname)
|
|
|
|
receiverPeerCtx, err := hdlCtx.getCtxByNick(startChatCReq.Nickname)
|
|
|
|
if err != nil {
|
|
logger.Debug("receiver peer not found")
|
|
return nil, nil
|
|
}
|
|
|
|
idx := slices.IndexFunc(hdlCtx.initiations, func(i *common.Initiation) bool {
|
|
return i.AbBNick == hdlCtx.peer.account.nickname && i.AbANick == startChatCReq.Nickname
|
|
})
|
|
|
|
if idx == -1 {
|
|
logger.Debug("initation not found, won't handle")
|
|
return nil, nil
|
|
}
|
|
|
|
if hdlCtx.initiations[idx].Stage != common.InitiationStageB {
|
|
logger.Debug("initation found, but is not in stage B, won't handle")
|
|
return nil, nil
|
|
}
|
|
|
|
hdlCtx.initiations[idx].Stage = common.InitiationStageC
|
|
|
|
aCode, err := generatePunchCode()
|
|
if err != nil {
|
|
logger.Error("failed generating punch code for a")
|
|
return nil, nil
|
|
}
|
|
|
|
bCode, err := generatePunchCode()
|
|
if err != nil {
|
|
logger.Error("failed generating punch code for b")
|
|
return nil, nil
|
|
}
|
|
|
|
hdlCtx.initiations[idx].AbAPunchCode = aCode
|
|
hdlCtx.initiations[idx].AbBPunchCode = bCode
|
|
|
|
dReqToA := common.StartChatDRequest{Nickname: hdlCtx.peer.account.nickname, PunchCode: aCode}
|
|
dReqToB := common.StartChatDRequest{Nickname: startChatCReq.Nickname, PunchCode: bCode}
|
|
|
|
err = hdlCtx.sendRequest(dReqToB)
|
|
|
|
if err != nil {
|
|
logger.Errorf("could not send chatstartd to B=%s", hdlCtx.peer.account.nickname)
|
|
return nil, nil
|
|
}
|
|
|
|
err = receiverPeerCtx.sendRequest(dReqToA)
|
|
|
|
if err != nil {
|
|
logger.Errorf("could not send chatstartd to A=%s", startChatCReq.Nickname)
|
|
return nil, nil
|
|
}
|
|
|
|
hdlCtx.initiations[idx].Stage = common.InitiationStageD
|
|
|
|
hdlCtx.initiationsLock.Unlock()
|
|
return nil, nil
|
|
}
|
|
|
|
// generatePunchCode returns a random code of 8 upper-case ASCII letters
// (A-Z) drawn from crypto/rand, or the error reported by the random source.
//
// Random bytes that would make some letters more likely than others are
// discarded (rejection sampling), so every letter is equally probable. A
// plain `b % 26` over 0..255 favours the first 22 letters.
func generatePunchCode() (string, error) {
	const (
		codeLen      = 8
		alphabetSize = 26
		// Bytes in [rejectFrom, 255] are skipped: 256 is not a multiple of
		// 26, so keeping them would introduce modulo bias.
		rejectFrom = 256 - 256%alphabetSize
	)

	code := make([]byte, 0, codeLen)
	randomBytes := make([]byte, codeLen)

	for len(code) < codeLen {
		if _, err := rand.Read(randomBytes); err != nil {
			return "", err
		}

		for _, rb := range randomBytes {
			if int(rb) >= rejectFrom {
				continue
			}

			code = append(code, 'A'+rb%alphabetSize)

			if len(code) == codeLen {
				break
			}
		}
	}

	return string(code), nil
}
|
|
|
|
func (ctx *Context) printDebugInfo() {
|
|
ctx.peersListLock.RLock()
|
|
logger.Debug("================================ server state")
|
|
logger.Debug("displaying all connections:")
|
|
|
|
for _, p := range ctx.peersList {
|
|
nick := "-"
|
|
|
|
if p.hasAccount {
|
|
nick = p.account.nickname
|
|
}
|
|
|
|
logger.Debugf("ID#%d, Addr:%s, Auth:%t, Nick:%s", p.id, p.conn.RemoteAddr(), p.hasAccount, nick)
|
|
}
|
|
|
|
logger.Debug("displaying all initiations:")
|
|
|
|
for _, i := range ctx.initiations {
|
|
logger.Debugf("from %s to %s, stage: %d", i.AbANick, i.AbBNick, i.Stage)
|
|
}
|
|
|
|
ctx.peersListLock.RUnlock()
|
|
}
|
|
|
|
func (ctx *Context) addPeer(peer *Peer) {
|
|
ctx.idCounterLock.Lock()
|
|
ctx.idCounter++
|
|
peer.id = ctx.idCounter
|
|
ctx.idCounterLock.Unlock()
|
|
ctx.peersListLock.Lock()
|
|
ctx.peersList = append(ctx.peersList, peer)
|
|
ctx.peersListLock.Unlock()
|
|
}
|
|
|
|
func testEcho(hdlCtx *HandlerContext) {
|
|
logger.Debug("sending echo request...")
|
|
_ = hdlCtx.sendRequest(common.EchoRequest{EchoByte: 123})
|
|
logger.Debug("sent")
|
|
echoResF := hdlCtx.getResponseFrame()
|
|
logger.Debug("got response")
|
|
echoRes, err := common.ResponseFromFrame[common.EchoResponse](echoResF)
|
|
|
|
if err != nil {
|
|
logger.Error(err)
|
|
return
|
|
}
|
|
|
|
logger.Debug("test echo done", "byteSent", 123, "byteReceived", echoRes.EchoByte)
|
|
}
|
|
|
|
func (ctx *Context) wsapiHandler(w http.ResponseWriter, r *http.Request) {
|
|
upgrader := websocket.Upgrader{}
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
|
|
if err != nil {
|
|
logger.Errorf("upgrade failed")
|
|
return
|
|
}
|
|
|
|
peer := NewPeer(conn)
|
|
ctx.addPeer(peer)
|
|
handlerCtx := NewHandlerContext(peer, ctx)
|
|
ctx.handlerContextsLock.Lock()
|
|
ctx.handlerContexts = append(ctx.handlerContexts, handlerCtx)
|
|
ctx.handlerContextsLock.Unlock()
|
|
defer handleDisconnection(handlerCtx)
|
|
|
|
defer func(conn *websocket.Conn) {
|
|
err := conn.Close()
|
|
if err != nil {
|
|
logger.Error(err)
|
|
}
|
|
}(conn)
|
|
|
|
logger.Infof("%s connected", conn.RemoteAddr())
|
|
|
|
errGroup, syncCtx := errgroup.WithContext(context.Background())
|
|
|
|
errGroup.Go(func() error {
|
|
return handlerCtx.clientHandler(syncCtx)
|
|
})
|
|
|
|
errGroup.Go(func() error {
|
|
return handlerCtx.clientWriter(syncCtx)
|
|
})
|
|
|
|
errGroup.Go(func() error {
|
|
return handlerCtx.clientReader(syncCtx)
|
|
})
|
|
|
|
errGroup.Go(func() error {
|
|
<-syncCtx.Done()
|
|
time.Sleep(time.Second * 3)
|
|
close(handlerCtx.rToClient)
|
|
close(handlerCtx.resFromClient)
|
|
close(handlerCtx.reqFromClient)
|
|
return conn.Close()
|
|
})
|
|
|
|
testEcho(handlerCtx)
|
|
err = errGroup.Wait()
|
|
|
|
if err != nil {
|
|
logger.Error(err)
|
|
return
|
|
}
|
|
}
|
|
|
|
func RunServer() {
|
|
srvCtx := NewContext()
|
|
|
|
go func() {
|
|
for {
|
|
srvCtx.printDebugInfo()
|
|
time.Sleep(time.Second * 5)
|
|
}
|
|
}()
|
|
|
|
http.HandleFunc("/wsapi", srvCtx.wsapiHandler)
|
|
logger.Info("Starting server...")
|
|
err := http.ListenAndServe(":8080", nil)
|
|
|
|
if err != nil {
|
|
logger.Error(err)
|
|
}
|
|
}
|