diff --git a/daemonconn/conninit.go b/daemonconn/conninit.go new file mode 100644 index 0000000..b652848 --- /dev/null +++ b/daemonconn/conninit.go @@ -0,0 +1,34 @@ +package daemonconn + +import ( + "net" +) + +const DefaultUnixSocketPath = "/tmp/liberum.sock" + +func UnixSocketListen(socketPath string) (<-chan DaemonConn, error) { + listener, err := net.Listen("unix", "/tmp/liberum.sock") + if err != nil { + return nil, err + } + + connChan := make(chan DaemonConn, 32) + + go func() { + var conn net.Conn + for conn, err = listener.Accept(); err == nil; { + connChan <- FromConn(conn) + } + }() + + return connChan, err +} + +func UnixSocketConnect(socketPath string) (DaemonConn, error) { + conn, err := net.Dial("unix", DefaultUnixSocketPath) + if err != nil { + return DaemonConn{}, err + } + + return FromConn(conn), nil +} diff --git a/daemonconn/daemonconn.go b/daemonconn/daemonconn.go new file mode 100644 index 0000000..28578d2 --- /dev/null +++ b/daemonconn/daemonconn.go @@ -0,0 +1,69 @@ +package daemonconn + +import ( + "encoding/json" + "errors" + "net" +) + +type Message interface { + TypeID() uint32 +} + +type DaemonConn struct { + FrameReadWriter +} + +func FromConn(conn net.Conn) DaemonConn { + return DaemonConn{FrameReadWriter{conn}} +} + +func (dc DaemonConn) ReadMessage() (Message, error) { + frame, err := dc.FrameReadWriter.ReadFrame() + if err != nil { + return nil, err + } + + var msg Message + + switch frame.Type() { + case MessageTypeEchoRequest: + msg, err = frameToMessage[EchoRequest](frame) + case MessageTypeEchoResponse: + msg, err = frameToMessage[EchoResponse](frame) + default: + err = errors.New("unknown frame type") + } + + if err != nil { + return nil, err + } + + return msg, nil +} + +func (dc DaemonConn) WriteMessage(msg Message) error { + msgBytes, err := json.Marshal(msg) + if err != nil { + return err + } + + f := frame{typeID: msg.TypeID(), valueBytes: msgBytes} + err = dc.WriteFrame(f) + if err != nil { + return err + } + + return nil +} + +func frameToMessage[M Message](frame Frame) (Message, error) { + var msg M + + err := json.Unmarshal(frame.Value(), &msg) + if err != nil { + return nil, err + } + + return msg, nil +} diff --git a/frameconn.go b/daemonconn/framereadwriter.go similarity index 98% rename from frameconn.go rename to daemonconn/framereadwriter.go index ccc8f6d..cddae5e 100644 --- a/frameconn.go +++ b/daemonconn/framereadwriter.go @@ -1,4 +1,4 @@ -package main +package daemonconn import ( "encoding/binary" diff --git a/daemonconn/messages.go b/daemonconn/messages.go new file mode 100644 index 0000000..de890aa --- /dev/null +++ b/daemonconn/messages.go @@ -0,0 +1,20 @@ +package daemonconn + +const ( + MessageTypeEchoRequest = iota + MessageTypeEchoResponse +) + +type EchoRequest struct { + EchoByte byte +} + +func (EchoRequest) TypeID() uint32 { + return MessageTypeEchoRequest +} + +type EchoResponse EchoRequest + +func (EchoResponse) TypeID() uint32 { + return MessageTypeEchoResponse +} diff --git a/main.go b/main.go index 0b33a73..e420385 100644 --- a/main.go +++ b/main.go @@ -1,226 +1,15 @@ package main import ( - "encoding/binary" - "encoding/json" - "errors" + "daemonSocketExample/daemonconn" + "daemonSocketExample/msghandlers" "github.com/sevlyar/go-daemon" - "io" "log" "log/slog" - "net" "os" "time" ) -type Message interface { - ID() uint32 -} - -const ( - EchoRequestID = iota -) - -type EchoRequest struct { - EchoByte byte -} - -type MessageHandler func(m Message) error - -var Handlers = map[uint32]MessageHandler{ - EchoRequestID: handleEchoRequest, -} - -func (EchoRequest) ID() uint32 { - return EchoRequestID -} - -func unixSocketListen() error { - listener, err := net.Listen("unix", "/tmp/liberum.sock") - if err != nil { - return err - } - - for conn, err := listener.Accept(); err == nil; { - err = handleConnection(conn) - if err != nil { - slog.Error("Error handling connection:", err) - } - } - - return nil -} - -func handleConnection(conn net.Conn) error { - msgChan := readSocketMessages(conn) - - for msg := range msgChan { - err := handleMessage(msg) - if err != nil { - return err - } - } - - return nil -} - -func readSocketMessages(conn net.Conn) chan Message { - msgChan := make(chan Message, 64) - - go func() { - for { - msgBytes, err := readMessage(conn) - if err != nil { - slog.Error("Error reading message", "error", err) - } - - msg, err := decodeMessage(msgBytes) - if err != nil { - slog.Error("Error parsing message", "error", err) - } - - msgChan <- msg - } - }() - - return msgChan -} - -func decodeMessage(msgBytes []byte) (Message, error) { - if len(msgBytes) < 4 { - return nil, errors.New("message too short to have type ID") - } - - msgID := binary.LittleEndian.Uint32(msgBytes[0:4]) - msgRest := msgBytes[4:] - var err error - var msg Message - - switch msgID { - case EchoRequestID: - var echoReq EchoRequest - err = json.Unmarshal(msgRest, &echoReq) - msg = echoReq - default: - err = errors.New("unknown message type ID") - } - - if err != nil { - return nil, err - } - - return msg, nil -} - -func readMessage(conn net.Conn) ([]byte, error) { - msgLenBuf := make([]byte, 4) - n, err := io.ReadFull(conn, msgLenBuf) - if err != nil { - return nil, err - } - - if n != 4 { - return nil, errors.New("could not read message length") - } - - msgLen := binary.LittleEndian.Uint32(msgLenBuf) - msgContent := make([]byte, msgLen) - n, err = io.ReadFull(conn, msgContent) - if err != nil { - return nil, err - } - - if uint32(n) != msgLen { - return nil, errors.New("could not read full message") - } - - return msgContent, nil -} - -func handleMessage(msg Message) error { - handler, ok := Handlers[msg.ID()] - if !ok { - return errors.New("message handler not defined for a given ID") - } - - err := handler(msg) - if err != nil { - return err - } - - return nil -} - -func handleEchoRequest(echoRequestMsg Message) error { - echoRequest := echoRequestMsg.(EchoRequest) - slog.Info("Got echo request", "EchoByte", echoRequest.EchoByte) - - return nil -} - -func writeMessage(conn net.Conn, msg Message) error { - msgBytes, err := encodeMessage(msg) - if err != nil { - return err - } - - n, err := conn.Write(msgBytes) - if err != nil { - return err - } - - if n != len(msgBytes) { - return errors.New("could not write full message") - } - - return nil -} - -func encodeMessage(msg Message) ([]byte, error) { - var msgBytes []byte - var msgJsonBytes []byte - var err error - - msgJsonBytes, err = json.Marshal(msg) - if err != nil { - return nil, err - } - - // +4 for type field length - msgLen := len(msgJsonBytes) + 4 - msgLenBytes := make([]byte, 4) - binary.LittleEndian.PutUint32(msgLenBytes, uint32(msgLen)) - msgBytes = append(msgBytes, msgLenBytes...) - - msgTypeBytes := make([]byte, 4) - binary.LittleEndian.PutUint32(msgTypeBytes, msg.ID()) - msgBytes = append(msgBytes, msgTypeBytes...) - - msgBytes = append(msgBytes, msgJsonBytes...) - - return msgBytes, nil -} - -func writeSocketMessages(conn net.Conn, msgChan <-chan Message) error { - for msg := range msgChan { - err := writeMessage(conn, msg) - if err != nil { - return err - } - } - - return nil -} - -func unixSocketConnect() (net.Conn, error) { - conn, err := net.Dial("unix", "/tmp/liberum.sock") - if err != nil { - return nil, err - } - - return conn, nil -} - func main() { ctx := &daemon.Context{ PidFileName: "liberum-daemon.pid", @@ -248,28 +37,29 @@ func main() { _ = os.Remove("/tmp/liberum.sock") - go func() { - err = unixSocketListen() - if err != nil { - slog.Error("Error listening unix socket", "error", err) - } - }() - - time.Sleep(1 * time.Second) - conn, err := unixSocketConnect() + connChan, err := daemonconn.UnixSocketListen(daemonconn.DefaultUnixSocketPath) if err != nil { panic(err) } - msgWriteChan := make(chan Message, 64) - go func() { - err = writeSocketMessages(conn, msgWriteChan) - if err != nil { - panic(err) + for conn := range connChan { + err := msghandlers.HandleDaemonConn(conn) + if err != nil { + panic(err) + } } }() - msgWriteChan <- EchoRequest{123} + conn, err := daemonconn.UnixSocketConnect(daemonconn.DefaultUnixSocketPath) + if err != nil { + panic(err) + } + + err = conn.WriteMessage(daemonconn.EchoRequest{EchoByte: 123}) + if err != nil { + panic(err) + } + time.Sleep(time.Second * 5) } diff --git a/msghandlers/msghandlers.go b/msghandlers/msghandlers.go new file mode 100644 index 0000000..6cc6856 --- /dev/null +++ b/msghandlers/msghandlers.go @@ -0,0 +1,40 @@ +package msghandlers + +import ( + "daemonSocketExample/daemonconn" + "log/slog" +) + +type MessageHandler func(daemonconn.Message) error + +var Handlers = map[uint32]MessageHandler{ + daemonconn.MessageTypeEchoRequest: handleEchoRequest, +} + +func HandleDaemonConn(conn daemonconn.DaemonConn) error { + msg, err := conn.ReadMessage() + + for err == nil { + err := handleMessage(msg) + if err != nil { + return err + } + + msg, err = conn.ReadMessage() + } + + return err +} + +func handleMessage(msg daemonconn.Message) error { + handlerFunc := Handlers[msg.TypeID()] + + return handlerFunc(msg) +} + +func handleEchoRequest(msg daemonconn.Message) error { + echoRequest := msg.(daemonconn.EchoRequest) + slog.Info("Got echo request", "echoByte", echoRequest.EchoByte) + + return nil +}