Pretty big cleanup

This commit is contained in:
Maciej Krzyżanowski 2024-12-28 19:11:50 +01:00
parent e8d917fcef
commit fa59d2403a
6 changed files with 182 additions and 229 deletions

34
daemonconn/conninit.go Normal file
View File

@ -0,0 +1,34 @@
package daemonconn
import (
"net"
)
const DefaultUnixSocketPath = "/tmp/liberum.sock"
func UnixSocketListen(socketPath string) (<-chan DaemonConn, error) {
listener, err := net.Listen("unix", "/tmp/liberum.sock")
if err != nil {
return nil, err
}
connChan := make(chan DaemonConn, 32)
go func() {
var conn net.Conn
for conn, err = listener.Accept(); err == nil; {
connChan <- FromConn(conn)
}
}()
return connChan, err
}
func UnixSocketConnect(socketPath string) (DaemonConn, error) {
conn, err := net.Dial("unix", DefaultUnixSocketPath)
if err != nil {
return DaemonConn{}, err
}
return FromConn(conn), nil
}

69
daemonconn/daemonconn.go Normal file
View File

@ -0,0 +1,69 @@
package daemonconn
import (
"encoding/json"
"errors"
"net"
)
type Message interface {
TypeID() uint32
}
type DaemonConn struct {
FrameReadWriter
}
func FromConn(conn net.Conn) DaemonConn {
return DaemonConn{FrameReadWriter{conn}}
}
func (dc DaemonConn) ReadMessage() (Message, error) {
frame, err := dc.FrameReadWriter.ReadFrame()
if err != nil {
return nil, err
}
var msg Message
switch frame.Type() {
case MessageTypeEchoRequest:
msg, err = frameToMessage[EchoRequest](frame)
case MessageTypeEchoResponse:
msg, err = frameToMessage[EchoResponse](frame)
default:
err = errors.New("unknown frame type")
}
if err != nil {
return nil, err
}
return msg, nil
}
func (dc DaemonConn) WriteMessage(msg Message) error {
msgBytes, err := json.Marshal(msg)
if err != nil {
return err
}
f := frame{typeID: msg.TypeID(), valueBytes: msgBytes}
err = dc.WriteFrame(f)
if err != nil {
return err
}
return nil
}
func frameToMessage[M Message](frame Frame) (Message, error) {
var msg M
err := json.Unmarshal(frame.Value(), &msg)
if err != nil {
return nil, err
}
return msg, nil
}

View File

@ -1,4 +1,4 @@
package main package daemonconn
import ( import (
"encoding/binary" "encoding/binary"

20
daemonconn/messages.go Normal file
View File

@ -0,0 +1,20 @@
package daemonconn
const (
MessageTypeEchoRequest = iota
MessageTypeEchoResponse
)
type EchoRequest struct {
EchoByte byte
}
func (EchoRequest) TypeID() uint32 {
return MessageTypeEchoRequest
}
type EchoResponse EchoRequest
func (EchoResponse) TypeID() uint32 {
return MessageTypeEchoResponse
}

246
main.go
View File

@ -1,226 +1,15 @@
package main package main
import ( import (
"encoding/binary" "daemonSocketExample/daemonconn"
"encoding/json" "daemonSocketExample/msghandlers"
"errors"
"github.com/sevlyar/go-daemon" "github.com/sevlyar/go-daemon"
"io"
"log" "log"
"log/slog" "log/slog"
"net"
"os" "os"
"time" "time"
) )
type Message interface {
ID() uint32
}
const (
EchoRequestID = iota
)
type EchoRequest struct {
EchoByte byte
}
type MessageHandler func(m Message) error
var Handlers = map[uint32]MessageHandler{
EchoRequestID: handleEchoRequest,
}
func (EchoRequest) ID() uint32 {
return EchoRequestID
}
func unixSocketListen() error {
listener, err := net.Listen("unix", "/tmp/liberum.sock")
if err != nil {
return err
}
for conn, err := listener.Accept(); err == nil; {
err = handleConnection(conn)
if err != nil {
slog.Error("Error handling connection:", err)
}
}
return nil
}
func handleConnection(conn net.Conn) error {
msgChan := readSocketMessages(conn)
for msg := range msgChan {
err := handleMessage(msg)
if err != nil {
return err
}
}
return nil
}
func readSocketMessages(conn net.Conn) chan Message {
msgChan := make(chan Message, 64)
go func() {
for {
msgBytes, err := readMessage(conn)
if err != nil {
slog.Error("Error reading message", "error", err)
}
msg, err := decodeMessage(msgBytes)
if err != nil {
slog.Error("Error parsing message", "error", err)
}
msgChan <- msg
}
}()
return msgChan
}
func decodeMessage(msgBytes []byte) (Message, error) {
if len(msgBytes) < 4 {
return nil, errors.New("message too short to have type ID")
}
msgID := binary.LittleEndian.Uint32(msgBytes[0:4])
msgRest := msgBytes[4:]
var err error
var msg Message
switch msgID {
case EchoRequestID:
var echoReq EchoRequest
err = json.Unmarshal(msgRest, &echoReq)
msg = echoReq
default:
err = errors.New("unknown message type ID")
}
if err != nil {
return nil, err
}
return msg, nil
}
func readMessage(conn net.Conn) ([]byte, error) {
msgLenBuf := make([]byte, 4)
n, err := io.ReadFull(conn, msgLenBuf)
if err != nil {
return nil, err
}
if n != 4 {
return nil, errors.New("could not read message length")
}
msgLen := binary.LittleEndian.Uint32(msgLenBuf)
msgContent := make([]byte, msgLen)
n, err = io.ReadFull(conn, msgContent)
if err != nil {
return nil, err
}
if uint32(n) != msgLen {
return nil, errors.New("could not read full message")
}
return msgContent, nil
}
func handleMessage(msg Message) error {
handler, ok := Handlers[msg.ID()]
if !ok {
return errors.New("message handler not defined for a given ID")
}
err := handler(msg)
if err != nil {
return err
}
return nil
}
func handleEchoRequest(echoRequestMsg Message) error {
echoRequest := echoRequestMsg.(EchoRequest)
slog.Info("Got echo request", "EchoByte", echoRequest.EchoByte)
return nil
}
func writeMessage(conn net.Conn, msg Message) error {
msgBytes, err := encodeMessage(msg)
if err != nil {
return err
}
n, err := conn.Write(msgBytes)
if err != nil {
return err
}
if n != len(msgBytes) {
return errors.New("could not write full message")
}
return nil
}
func encodeMessage(msg Message) ([]byte, error) {
var msgBytes []byte
var msgJsonBytes []byte
var err error
msgJsonBytes, err = json.Marshal(msg)
if err != nil {
return nil, err
}
// +4 for type field length
msgLen := len(msgJsonBytes) + 4
msgLenBytes := make([]byte, 4)
binary.LittleEndian.PutUint32(msgLenBytes, uint32(msgLen))
msgBytes = append(msgBytes, msgLenBytes...)
msgTypeBytes := make([]byte, 4)
binary.LittleEndian.PutUint32(msgTypeBytes, msg.ID())
msgBytes = append(msgBytes, msgTypeBytes...)
msgBytes = append(msgBytes, msgJsonBytes...)
return msgBytes, nil
}
func writeSocketMessages(conn net.Conn, msgChan <-chan Message) error {
for msg := range msgChan {
err := writeMessage(conn, msg)
if err != nil {
return err
}
}
return nil
}
func unixSocketConnect() (net.Conn, error) {
conn, err := net.Dial("unix", "/tmp/liberum.sock")
if err != nil {
return nil, err
}
return conn, nil
}
func main() { func main() {
ctx := &daemon.Context{ ctx := &daemon.Context{
PidFileName: "liberum-daemon.pid", PidFileName: "liberum-daemon.pid",
@ -248,28 +37,29 @@ func main() {
_ = os.Remove("/tmp/liberum.sock") _ = os.Remove("/tmp/liberum.sock")
go func() { connChan, err := daemonconn.UnixSocketListen(daemonconn.DefaultUnixSocketPath)
err = unixSocketListen()
if err != nil {
slog.Error("Error listening unix socket", "error", err)
}
}()
time.Sleep(1 * time.Second)
conn, err := unixSocketConnect()
if err != nil { if err != nil {
panic(err) panic(err)
} }
msgWriteChan := make(chan Message, 64)
go func() { go func() {
err = writeSocketMessages(conn, msgWriteChan) for conn := range connChan {
if err != nil { err := msghandlers.HandleDaemonConn(conn)
panic(err) if err != nil {
panic(err)
}
} }
}() }()
msgWriteChan <- EchoRequest{123} conn, err := daemonconn.UnixSocketConnect(daemonconn.DefaultUnixSocketPath)
if err != nil {
panic(err)
}
err = conn.WriteMessage(daemonconn.EchoRequest{EchoByte: 123})
if err != nil {
panic(err)
}
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
} }

View File

@ -0,0 +1,40 @@
package msghandlers
import (
"daemonSocketExample/daemonconn"
"log/slog"
)
type MessageHandler func(daemonconn.Message) error
var Handlers = map[uint32]MessageHandler{
daemonconn.MessageTypeEchoRequest: handleEchoRequest,
}
func HandleDaemonConn(conn daemonconn.DaemonConn) error {
msg, err := conn.ReadMessage()
for err == nil {
err := handleMessage(msg)
if err != nil {
return err
}
msg, err = conn.ReadMessage()
}
return err
}
func handleMessage(msg daemonconn.Message) error {
handlerFunc := Handlers[msg.TypeID()]
return handlerFunc(msg)
}
func handleEchoRequest(msg daemonconn.Message) error {
echoRequest := msg.(daemonconn.EchoRequest)
slog.Info("Got echo request", "echoByte", echoRequest.EchoByte)
return nil
}