// Source: archat-server/server/server.go (389 lines, 8.8 KiB, Go)

package server
import (
"encoding/json"
"golang.org/x/sync/errgroup"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/charmbracelet/log"
"github.com/gorilla/websocket"
"golang.org/x/crypto/bcrypt"
"krzyzanowski.dev/p2pchat/common"
)
// Account is a registered chat identity: a nickname plus the hash of its password.
// Field order matters: handleAuth builds Account with an unkeyed literal.
type Account struct {
// nickname names the account; handleAuth also uses it as the key in Context.accounts.
nickname string
// passHash is the bcrypt hash that handleAuth generates when the account is created.
passHash []byte
}
// Context is the server-wide state shared by every connection handler.
type Context struct {
// idCounter is the most recently assigned peer ID; addPeer increments it before use.
idCounter int
// idCounterLock guards idCounter.
idCounterLock sync.RWMutex
// peersList holds the peers registered through addPeer and not yet removed.
peersList []*Peer
// peersListLock guards peersList; handleAuth also holds it while updating a
// peer's hasAccount/account fields.
peersListLock sync.RWMutex
// accounts maps a nickname to its Account.
accounts map[string]*Account
// accountsLock guards accounts.
accountsLock sync.RWMutex
}
// HandlerContext is the per-connection state: the peer being served, the shared
// server Context (embedded), and the channels linking the reader, handler and
// writer loops. Field order matters: NewHandlerContext uses an unkeyed literal.
type HandlerContext struct {
// peer is the connection this handler serves.
peer *Peer
*Context
// resFromClient carries response frames from clientReader to getResponseFrame.
resFromClient chan common.RFrame
// reqFromClient carries request frames from clientReader to clientHandler.
reqFromClient chan common.RFrame
// rToClient carries frames (responses and server-originated requests) to clientWriter.
rToClient chan common.RFrame
}
// NewHandlerContext builds the per-connection handler state for peer on top of
// the shared server context srvCtx, creating the three unbuffered frame
// channels used by the reader, handler and writer loops.
func NewHandlerContext(peer *Peer, srvCtx *Context) *HandlerContext {
	return &HandlerContext{
		peer:          peer,
		Context:       srvCtx,
		resFromClient: make(chan common.RFrame),
		reqFromClient: make(chan common.RFrame),
		rToClient:     make(chan common.RFrame),
	}
}
// clientHandler is the request-processing loop of a connection. It takes
// request frames from reqFromClient, dispatches them by frame ID to the
// matching handler, wraps the handler's response in a frame and queues it on
// rToClient for clientWriter.
//
// The loop only returns when a handler fails or the response cannot be
// framed; that error is returned to the caller. Frames whose ID matches no
// known request type are logged and skipped instead of being answered with a
// nil response.
func (hdlCtx *HandlerContext) clientHandler() error {
	for {
		reqFrame := <-hdlCtx.reqFromClient

		var (
			res common.Response
			err error
		)
		switch reqFrame.ID {
		case common.AuthReqID:
			res, err = hdlCtx.handleAuth(&reqFrame)
		case common.ListPeersReqID:
			res, err = hdlCtx.handleListPeers(&reqFrame)
		case common.EchoReqID:
			res, err = hdlCtx.handleEcho(&reqFrame)
		default:
			// No handler produced a response, so there is nothing valid to frame.
			logger.Warnf("ignoring request with unknown ID=%d", reqFrame.ID)
			continue
		}
		if err != nil {
			logger.Errorf("could not handle request ID=%d", reqFrame.ID)
			return err
		}

		resFrame, err := common.ResponseFrameFrom(res)
		if err != nil {
			logger.Errorf("could not create frame from response")
			return err
		}
		hdlCtx.rToClient <- resFrame
	}
}
// clientWriter is the outbound loop of a connection: every frame queued on
// rToClient is JSON-encoded and written to the peer as a websocket text
// message. It returns the first marshalling or write error.
func (hdlCtx *HandlerContext) clientWriter() error {
	for {
		outgoing := <-hdlCtx.rToClient

		payload, marshalErr := json.Marshal(outgoing)
		if marshalErr != nil {
			logger.Errorf("error marshalling frame to json")
			return marshalErr
		}
		logger.Debugf("sending %s", string(payload))

		if writeErr := hdlCtx.peer.conn.WriteMessage(websocket.TextMessage, payload); writeErr != nil {
			logger.Errorf("error writing rframe")
			return writeErr
		}
	}
}
// clientReader is the inbound loop of a connection. It reads websocket
// messages from the peer, decodes each text message as a JSON RFrame and
// routes it to resFromClient (responses) or reqFromClient (requests).
//
// Non-text messages are rejected with a proper close frame carrying the
// "unsupported data" close code and are not decoded. The loop returns the
// first read, close-write or JSON decoding error.
func (hdlCtx *HandlerContext) clientReader() error {
	// Upper bound on how long sending the close frame may block.
	const closeWriteTimeout = time.Second

	for {
		messType, messBytes, err := hdlCtx.peer.conn.ReadMessage()
		if err != nil {
			return err
		}

		if messType != websocket.TextMessage {
			// The close code belongs in the payload of a CloseMessage, not in the
			// message type. WriteControl is used because clientWriter may be
			// writing on the same connection concurrently.
			closeMsg := websocket.FormatCloseMessage(websocket.CloseUnsupportedData, "Only JSON text is supported")
			err := hdlCtx.peer.conn.WriteControl(websocket.CloseMessage, closeMsg, time.Now().Add(closeWriteTimeout))
			if err != nil {
				logger.Debugf("[Server] error sending unsupported data close message")
				return err
			}
			// Do not try to decode the unsupported payload; the next ReadMessage
			// reports the peer's close reply (or any further failure).
			continue
		}

		logger.Debugf("got message text: %s", strings.Trim(string(messBytes), "\n"))

		var rFrame common.RFrame
		err = json.Unmarshal(messBytes, &rFrame)
		if err != nil {
			return err
		}
		logger.Debugf("unmarshalled request frame (ID=%d)", rFrame.ID)

		if rFrame.IsResponse() {
			logger.Debug("it is response frame", "id", rFrame.ID)
			hdlCtx.resFromClient <- rFrame
		} else {
			logger.Debug("it is request frame", "id", rFrame.ID)
			hdlCtx.reqFromClient <- rFrame
		}
	}
}
// sendRequest wraps req in a request frame and queues it on rToClient for
// clientWriter. It returns the framing error, if any; in that case nothing
// is queued. The send blocks until the writer loop takes the frame.
func (hdlCtx *HandlerContext) sendRequest(req common.Request) error {
	frame, err := common.RequestFrameFrom(req)
	if err == nil {
		hdlCtx.rToClient <- frame
	}
	return err
}
// getResponseFrame blocks until clientReader delivers the next response frame
// from the peer and returns it.
func (hdlCtx *HandlerContext) getResponseFrame() common.RFrame {
	frame := <-hdlCtx.resFromClient
	return frame
}
// Peer is one connected websocket client.
// Field order matters: NewPeer builds Peer with an unkeyed literal.
type Peer struct {
// id is -1 as created by NewPeer; addPeer replaces it with a counter-based ID.
id int
// conn is the websocket connection to the client.
conn *websocket.Conn
// hasAccount is set to true by handleAuth after a successful authentication.
hasAccount bool
// account is nil until handleAuth binds the authenticated Account.
account *Account
}
// logger is the package-wide logger. It writes to stdout with a time-only
// timestamp and a fixed prefix; its level is set in init.
var logger = log.NewWithOptions(os.Stdout, log.Options{
ReportTimestamp: true,
TimeFormat: time.TimeOnly,
Prefix: "⚙️ Server",
})
func init() {
if common.IsProd {
logger.SetLevel(log.InfoLevel)
} else {
logger.SetLevel(log.DebugLevel)
}
}
// NewPeer wraps conn in a Peer that has no account and carries the
// placeholder ID -1.
func NewPeer(conn *websocket.Conn) *Peer {
	return &Peer{
		id:         -1,
		conn:       conn,
		hasAccount: false,
		account:    nil,
	}
}
// peerSliceIndexOf returns the index of the first peer in s whose id equals
// id, or -1 when no such peer exists (including when s is empty).
func peerSliceIndexOf(s []*Peer, id int) int {
	for i, p := range s {
		if p.id == id {
			return i
		}
	}
	return -1
}

// peerSliceRemove deletes the element at index i from *s by moving the last
// element into its place, so element order is not preserved. An index outside
// the slice (such as the -1 "not found" result of peerSliceIndexOf) is ignored.
func peerSliceRemove(s *[]*Peer, i int) {
	if i < 0 || i >= len(*s) {
		return
	}
	last := len(*s) - 1
	(*s)[i] = (*s)[last]
	// Clear the vacated slot so the backing array does not keep the peer alive.
	(*s)[last] = nil
	*s = (*s)[:last]
}

// removePeer removes peer from the list of connected peers, matching by ID.
// It is a no-op when the peer is not in the list.
func (srvCtx *Context) removePeer(peer *Peer) {
	srvCtx.peersListLock.Lock()
	defer srvCtx.peersListLock.Unlock()

	peerSliceRemove(&srvCtx.peersList, peerSliceIndexOf(srvCtx.peersList, peer.id))
}
// handleDisconnection unregisters the handler's peer from the server and logs
// that its remote address disconnected.
func handleDisconnection(handlerCtx *HandlerContext) {
	peer := handlerCtx.peer
	handlerCtx.removePeer(peer)
	logger.Infof("%s disconnected", peer.conn.RemoteAddr())
}
// handleEcho decodes an echo request from reqFrame and answers with an echo
// response converted directly from the request. It returns a nil response and
// the decoding error when the frame cannot be read.
func (hdlCtx *HandlerContext) handleEcho(reqFrame *common.RFrame) (res common.Response, err error) {
	request, decodeErr := common.RequestFromFrame[common.EchoRequest](*reqFrame)
	if decodeErr != nil {
		logger.Error("could not read request from frame")
		return nil, decodeErr
	}
	return common.EchoResponse(request), nil
}
// handleListPeers answers a list-peers request with one PeerInfo entry (ID,
// remote address, whether a nickname is set, and the nickname) for every
// connected peer. It returns a nil response and the decoding error when the
// frame cannot be read.
//
// Peers that have not authenticated have no account; their Nickname is left
// empty instead of dereferencing the nil account.
func (hdlCtx *HandlerContext) handleListPeers(reqFrame *common.RFrame) (res common.Response, err error) {
	// The list-peers request currently has no fields; it is decoded only to
	// validate the frame.
	_, err = common.RequestFromFrame[common.ListPeersRequest](*reqFrame)
	if err != nil {
		logger.Error("could not read request from frame")
		return nil, err
	}

	// handleAuth updates a peer's hasAccount/account under peersListLock, so
	// those fields are read while the lock is held, not from a copied slice.
	hdlCtx.peersListLock.RLock()
	defer hdlCtx.peersListLock.RUnlock()

	peersInfo := make([]common.PeerInfo, 0, len(hdlCtx.peersList))
	for _, peer := range hdlCtx.peersList {
		info := common.PeerInfo{
			ID:          peer.id,
			Addr:        peer.conn.RemoteAddr().String(),
			HasNickname: peer.hasAccount,
		}
		if peer.hasAccount && peer.account != nil {
			info.Nickname = peer.account.nickname
		}
		peersInfo = append(peersInfo, info)
	}
	return common.ListPeersResponse{PeersInfo: peersInfo}, nil
}
// handleAuth processes an authentication request. If an account with the
// requested nickname exists, the password is checked against its bcrypt hash;
// otherwise a new account is registered with that nickname and password. On
// success the account is bound to the handler's peer.
//
// The response's IsSuccess is false when the password does not match or the
// password cannot be hashed. A non-nil error is returned only when the request
// frame cannot be decoded.
func (hdlCtx *HandlerContext) handleAuth(reqFrame *common.RFrame) (res common.Response, err error) {
	authReq, err := common.RequestFromFrame[common.AuthRequest](*reqFrame)
	if err != nil {
		logger.Error("could not read request from frame")
		return nil, err
	}
	password := []byte(authReq.Password)

	hdlCtx.accountsLock.RLock()
	account, exists := hdlCtx.accounts[authReq.Nickname]
	hdlCtx.accountsLock.RUnlock()

	registered := false
	if !exists {
		// Hash outside the write lock: bcrypt is deliberately slow and must not
		// stall every other authentication.
		passHash, hashErr := bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost)
		if hashErr != nil {
			return &common.AuthResponse{IsSuccess: false}, nil
		}
		newAcc := &Account{nickname: authReq.Nickname, passHash: passHash}

		// Re-check under the write lock: another connection may have registered
		// the same nickname since the lookup above. The first registration wins
		// and this request is then treated as a login attempt against it.
		hdlCtx.accountsLock.Lock()
		if winner, taken := hdlCtx.accounts[authReq.Nickname]; taken {
			account = winner
		} else {
			hdlCtx.accounts[authReq.Nickname] = newAcc
			account = newAcc
			registered = true
		}
		hdlCtx.accountsLock.Unlock()
	}

	if !registered && bcrypt.CompareHashAndPassword(account.passHash, password) != nil {
		return &common.AuthResponse{IsSuccess: false}, nil
	}

	// Peer account fields are read by other goroutines under peersListLock.
	hdlCtx.peersListLock.Lock()
	hdlCtx.peer.hasAccount = true
	hdlCtx.peer.account = account
	hdlCtx.peersListLock.Unlock()

	return &common.AuthResponse{IsSuccess: true}, nil
}
// printConnectedPeers writes one debug log line per connected peer with its
// ID, remote address, authentication state and nickname ("-" when the peer
// has no account). The peer list is read-locked for the duration.
func (srvCtx *Context) printConnectedPeers() {
	srvCtx.peersListLock.RLock()
	defer srvCtx.peersListLock.RUnlock()

	logger.Debug("displaying all connections:")
	for idx := range srvCtx.peersList {
		current := srvCtx.peersList[idx]

		var nick string
		if current.hasAccount {
			nick = current.account.nickname
		} else {
			nick = "-"
		}
		logger.Debugf("ID#%d, Addr:%s, Auth:%t, Nick:%s", current.id, current.conn.RemoteAddr(), current.hasAccount, nick)
	}
}
// addPeer assigns peer the next ID from the server's counter and appends it
// to the list of connected peers. The counter and the list are each updated
// under their own lock.
func (srvCtx *Context) addPeer(peer *Peer) {
	// Step 1: reserve a fresh ID.
	srvCtx.idCounterLock.Lock()
	nextID := srvCtx.idCounter + 1
	srvCtx.idCounter = nextID
	peer.id = nextID
	srvCtx.idCounterLock.Unlock()

	// Step 2: publish the peer.
	srvCtx.peersListLock.Lock()
	srvCtx.peersList = append(srvCtx.peersList, peer)
	srvCtx.peersListLock.Unlock()
}
// testEcho sends an echo request carrying a fixed probe byte to the peer,
// waits for the next response frame and logs the byte that came back.
// It gives up (after logging) if the request cannot be sent or the response
// cannot be decoded; it blocks until a response frame arrives.
func testEcho(hdlCtx *HandlerContext) {
	const probeByte = 123

	logger.Debug("sending echo request...")
	// Without a request on the wire no response will ever arrive, so a send
	// failure must stop here instead of blocking in getResponseFrame.
	if err := hdlCtx.sendRequest(common.EchoRequest{EchoByte: probeByte}); err != nil {
		logger.Error(err)
		return
	}
	logger.Debug("sent")

	echoResF := hdlCtx.getResponseFrame()
	logger.Debug("got response")

	echoRes, err := common.ResponseFromFrame[common.EchoResponse](echoResF)
	if err != nil {
		logger.Error(err)
		return
	}
	logger.Debug("test echo done", "byteSent", probeByte, "byteReceived", echoRes.EchoByte)
}
// wsapiHandler upgrades an HTTP request to a websocket connection, registers
// the new peer, and runs the connection's handler, writer and reader loops
// until the first of them fails (typically the reader, when the client
// disconnects). On return the connection is closed and the peer is removed
// from the server's peer list.
func (srvCtx *Context) wsapiHandler(w http.ResponseWriter, r *http.Request) {
	upgrader := websocket.Upgrader{}
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		logger.Error("upgrade failed", "err", err)
		return
	}

	peer := NewPeer(conn)
	srvCtx.addPeer(peer)
	handlerCtx := NewHandlerContext(peer, srvCtx)

	// Deferred calls run in reverse order: close the connection first, then
	// unregister the peer.
	defer handleDisconnection(handlerCtx)
	defer func() {
		if closeErr := conn.Close(); closeErr != nil {
			logger.Error(closeErr)
		}
	}()
	logger.Infof("%s connected", conn.RemoteAddr())

	// groupCtx is cancelled as soon as any loop returns an error. Waiting on it
	// instead of errGroup.Wait matters: the handler and writer loops block on
	// channel receives and never return by themselves, so Wait would hang
	// forever and the deferred cleanup above would never run.
	errGroup, groupCtx := errgroup.WithContext(r.Context())
	logged := func(loop func() error) func() error {
		return func() error {
			loopErr := loop()
			if loopErr != nil {
				logger.Error(loopErr)
			}
			return loopErr
		}
	}
	errGroup.Go(logged(handlerCtx.clientHandler))
	errGroup.Go(logged(handlerCtx.clientWriter))
	errGroup.Go(logged(handlerCtx.clientReader))

	// The echo self-test waits for the client's reply; run it on the side so a
	// client that disconnects without answering cannot block the cleanup.
	go testEcho(handlerCtx)

	// NOTE(review): loops still blocked on their channels outlive this handler;
	// stopping them needs a done signal in HandlerContext.
	<-groupCtx.Done()
}
// RunServer creates the shared server context, starts a background goroutine
// that logs the connected peers every five seconds, registers the websocket
// endpoint at /wsapi on the default mux and serves HTTP on port 8080. It
// blocks until the HTTP server stops and logs the error it stopped with.
func RunServer() {
	srvCtx := &Context{
		peersList: make([]*Peer, 0),
		accounts:  make(map[string]*Account),
	}

	reportPeers := func() {
		for {
			srvCtx.printConnectedPeers()
			time.Sleep(time.Second * 5)
		}
	}
	go reportPeers()

	http.HandleFunc("/wsapi", srvCtx.wsapiHandler)
	logger.Info("Starting server...")
	if serveErr := http.ListenAndServe(":8080", nil); serveErr != nil {
		logger.Error(serveErr)
	}
}