276 lines
4.7 KiB
Go

package main
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/sevlyar/go-daemon"
"io"
"log"
"log/slog"
"net"
"os"
"time"
)
type (
	// Message is implemented by every value exchanged over the daemon's unix
	// socket. ID returns the numeric type tag that is written in front of the
	// message's JSON payload when it is framed.
	Message interface {
		ID() uint32
	}

	// EchoRequest is a message carrying a single byte.
	EchoRequest struct {
		EchoByte byte
	}
)

// Message type IDs. Each constant is the tag a frame's first four payload
// bytes carry for that message type; new types are appended here.
const (
	EchoRequestID = iota
)
// MessageHandler processes one decoded Message. A non-nil error is passed
// back unchanged to whoever dispatched the message.
type MessageHandler func(Message) error

// Handlers is the dispatch table used by handleMessage: it maps the value
// returned by a message's ID method to the function that processes it.
var Handlers = map[uint32]MessageHandler{
	EchoRequestID: handleEchoRequest, // logs the request's EchoByte
}
func (EchoRequest) ID() uint32 {
return EchoRequestID
}
// unixSocketListen listens on the unix socket /tmp/liberum.sock and serves
// every accepted connection on its own goroutine via handleConnection.
//
// It only returns when net.Listen or Accept fails, returning that error.
// The listener is closed on return, and each connection is closed once
// handleConnection is done with it.
func unixSocketListen() error {
	listener, err := net.Listen("unix", "/tmp/liberum.sock")
	if err != nil {
		return err
	}
	defer listener.Close()

	for {
		// Accept must run on every iteration; each call yields a new client.
		conn, err := listener.Accept()
		if err != nil {
			return err
		}
		// handleConnection keeps reading from a client until it fails, so
		// serving inline would block every other client behind this one.
		go func(c net.Conn) {
			defer c.Close()
			if err := handleConnection(c); err != nil {
				slog.Error("Error handling connection", "error", err)
			}
		}(conn)
	}
}
// handleConnection passes every message decoded from conn (via
// readSocketMessages) to handleMessage, one at a time in arrival order.
// It stops at the first handler error and returns it. It returns nil only
// if the message channel gets closed.
//
// NOTE(review): after an early error return nothing drains the channel, so
// the reader goroutine may remain blocked on a send. conn is not closed here.
func handleConnection(conn net.Conn) error {
	incoming := readSocketMessages(conn)
	for {
		msg, open := <-incoming
		if !open {
			return nil
		}
		if err := handleMessage(msg); err != nil {
			return err
		}
	}
}
// readSocketMessages starts a goroutine that reads frames from conn with
// readMessage, decodes them with decodeMessage, and delivers the results on
// the returned channel (buffered to 64 messages).
//
// The channel is closed as soon as reading from conn fails, so a consumer
// ranging over it terminates. A clean end of stream (io.EOF, i.e. the peer
// hung up between frames) and a locally closed conn are not logged; any
// other read error is.
//
// A frame that was read in full but fails to decode is logged and skipped.
// The length prefix already delimited it, so the next frame can still be
// read. Nil messages are never sent.
func readSocketMessages(conn net.Conn) chan Message {
	msgChan := make(chan Message, 64)
	go func() {
		defer close(msgChan)
		for {
			msgBytes, err := readMessage(conn)
			if err != nil {
				// The stream is unusable (or finished) after a read error;
				// continuing would only decode garbage.
				if !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
					slog.Error("Error reading message", "error", err)
				}
				return
			}
			msg, err := decodeMessage(msgBytes)
			if err != nil {
				slog.Error("Error parsing message", "error", err)
				continue
			}
			msgChan <- msg
		}
	}()
	return msgChan
}
// decodeMessage parses the content of one frame: a little-endian uint32
// type ID followed by the JSON encoding of the message.
//
// It returns an error when the content is shorter than the type ID, when
// the ID is not a known message type, or when the JSON cannot be
// unmarshalled into that type.
func decodeMessage(msgBytes []byte) (Message, error) {
	const idLen = 4
	if len(msgBytes) < idLen {
		return nil, errors.New("message too short to have type ID")
	}
	typeID, payload := binary.LittleEndian.Uint32(msgBytes[:idLen]), msgBytes[idLen:]

	switch typeID {
	case EchoRequestID:
		var req EchoRequest
		if err := json.Unmarshal(payload, &req); err != nil {
			return nil, err
		}
		return req, nil
	default:
		return nil, errors.New("unknown message type ID")
	}
}
// maxMessageLen caps the frame length readMessage will accept. The length
// prefix comes straight from the peer, so without a cap a corrupt or
// hostile frame could make readMessage allocate up to 4 GiB. Raise it if
// larger message types are added.
const maxMessageLen = 1 << 20

// errMessageTooLarge is returned by readMessage when a frame's length
// prefix exceeds maxMessageLen.
var errMessageTooLarge = errors.New("message length exceeds limit")

// readMessage reads one frame from conn: a little-endian uint32 length
// followed by that many bytes, which it returns.
//
// It returns io.EOF if conn ends cleanly before a new frame starts, and
// io.ErrUnexpectedEOF if it ends part-way through a frame. It returns
// errMessageTooLarge if the length exceeds maxMessageLen. After that
// error the stream is no longer positioned at a frame boundary, so callers
// should stop reading from conn.
func readMessage(conn net.Conn) ([]byte, error) {
	var lenBuf [4]byte
	// io.ReadFull only returns a nil error once the buffer is completely
	// filled, so no separate short-read check is needed.
	if _, err := io.ReadFull(conn, lenBuf[:]); err != nil {
		return nil, err
	}
	msgLen := binary.LittleEndian.Uint32(lenBuf[:])
	if msgLen > maxMessageLen {
		return nil, errMessageTooLarge
	}
	msgContent := make([]byte, msgLen)
	if _, err := io.ReadFull(conn, msgContent); err != nil {
		return nil, err
	}
	return msgContent, nil
}
// handleMessage looks up the handler registered in Handlers for msg.ID()
// and returns whatever that handler returns.
//
// It returns an error without calling any handler when msg is nil or when
// no handler is registered for its ID.
func handleMessage(msg Message) error {
	// Calling ID on a nil interface would panic; treat it as a bad message.
	if msg == nil {
		return errors.New("cannot handle nil message")
	}
	handler, ok := Handlers[msg.ID()]
	if !ok {
		return errors.New("message handler not defined for a given ID")
	}
	return handler(msg)
}
// handleEchoRequest is the Handlers entry for EchoRequestID. It logs the
// request's EchoByte at info level.
//
// If it is given a Message that is not an EchoRequest value, it returns an
// error instead of panicking on a failed type assertion.
func handleEchoRequest(echoRequestMsg Message) error {
	echoRequest, ok := echoRequestMsg.(EchoRequest)
	if !ok {
		return errors.New("handleEchoRequest: message is not an EchoRequest")
	}
	slog.Info("Got echo request", "EchoByte", echoRequest.EchoByte)
	return nil
}
// writeMessage frames msg with encodeMessage and writes the whole frame to
// conn in a single Write call.
//
// It returns the encoding or write error, or an error if conn accepted
// fewer bytes than the frame holds.
func writeMessage(conn net.Conn, msg Message) error {
	frame, err := encodeMessage(msg)
	if err != nil {
		return err
	}
	switch written, err := conn.Write(frame); {
	case err != nil:
		return err
	case written != len(frame):
		return errors.New("could not write full message")
	default:
		return nil
	}
}
func encodeMessage(msg Message) ([]byte, error) {
var msgBytes []byte
var msgJsonBytes []byte
var err error
msgJsonBytes, err = json.Marshal(msg)
if err != nil {
return nil, err
}
// +4 for type field length
msgLen := len(msgJsonBytes) + 4
msgLenBytes := make([]byte, 4)
binary.LittleEndian.PutUint32(msgLenBytes, uint32(msgLen))
msgBytes = append(msgBytes, msgLenBytes...)
msgTypeBytes := make([]byte, 4)
binary.LittleEndian.PutUint32(msgTypeBytes, msg.ID())
msgBytes = append(msgBytes, msgTypeBytes...)
msgBytes = append(msgBytes, msgJsonBytes...)
return msgBytes, nil
}
func writeSocketMessages(conn net.Conn, msgChan <-chan Message) error {
for msg := range msgChan {
err := writeMessage(conn, msg)
if err != nil {
return err
}
}
return nil
}
func unixSocketConnect() (net.Conn, error) {
conn, err := net.Dial("unix", "/tmp/liberum.sock")
if err != nil {
return nil, err
}
return conn, nil
}
func main() {
ctx := &daemon.Context{
PidFileName: "liberum-daemon.pid",
PidFilePerm: 0644,
LogFileName: "liberum-daemon.log",
LogFilePerm: 0640,
WorkDir: "./",
Umask: 027,
}
d, err := ctx.Reborn()
if err != nil {
panic(err)
}
if d != nil {
return
}
defer func(ctx *daemon.Context) {
slog.Info("Stopping daemon")
_ = ctx.Release()
}(ctx)
log.Println("Daemon started")
_ = os.Remove("/tmp/liberum.sock")
go func() {
err = unixSocketListen()
if err != nil {
slog.Error("Error listening unix socket", "error", err)
}
}()
time.Sleep(1 * time.Second)
conn, err := unixSocketConnect()
if err != nil {
panic(err)
}
msgWriteChan := make(chan Message, 64)
go func() {
err = writeSocketMessages(conn, msgWriteChan)
if err != nil {
panic(err)
}
}()
msgWriteChan <- EchoRequest{123}
time.Sleep(time.Second * 5)
}