258 lines
4.4 KiB
Go
258 lines
4.4 KiB
Go
package main
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"errors"
|
|
"github.com/sevlyar/go-daemon"
|
|
"io"
|
|
"log"
|
|
"log/slog"
|
|
"net"
|
|
"os"
|
|
"time"
|
|
)
|
|
|
|
// Message is the common interface of every value exchanged over the daemon
// socket.
type Message interface {
	// ID returns the numeric type identifier of the message.
	ID() uint32
}
|
|
|
|
// Message type identifiers. New identifiers are appended to this group so that
// iota keeps the existing values stable.
const (
	// EchoRequestID is the type identifier of EchoRequest (value 0).
	EchoRequestID = iota
)
|
|
|
|
// EchoRequest is a message carrying a single byte payload.
type EchoRequest struct {
	// EchoByte is the payload byte of the request.
	EchoByte byte
}
|
|
|
|
func (EchoRequest) ID() uint32 {
|
|
return EchoRequestID
|
|
}
|
|
|
|
func unixSocketListen() error {
|
|
listener, err := net.Listen("unix", "/tmp/liberum.sock")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for conn, err := listener.Accept(); err == nil; {
|
|
handleConnection(conn)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func handleConnection(conn net.Conn) {
|
|
msgChan := readSocketMessages(conn)
|
|
|
|
for msg := range msgChan {
|
|
handleMessage(msg)
|
|
}
|
|
}
|
|
|
|
func readSocketMessages(conn net.Conn) chan Message {
|
|
msgChan := make(chan Message, 64)
|
|
|
|
go func() {
|
|
for {
|
|
msgBytes, err := readMessage(conn)
|
|
if err != nil {
|
|
slog.Error("Error reading message", "error", err)
|
|
}
|
|
|
|
msg, err := decodeMessage(msgBytes)
|
|
if err != nil {
|
|
slog.Error("Error parsing message", "error", err)
|
|
}
|
|
|
|
msgChan <- msg
|
|
}
|
|
}()
|
|
|
|
return msgChan
|
|
}
|
|
|
|
func decodeMessage(msgBytes []byte) (Message, error) {
|
|
if len(msgBytes) < 4 {
|
|
return nil, errors.New("message too short to have type ID")
|
|
}
|
|
|
|
msgID := binary.LittleEndian.Uint32(msgBytes[0:4])
|
|
msgRest := msgBytes[4:]
|
|
var err error
|
|
var msg Message
|
|
|
|
switch msgID {
|
|
case EchoRequestID:
|
|
var echoReq EchoRequest
|
|
err = json.Unmarshal(msgRest, &echoReq)
|
|
msg = echoReq
|
|
default:
|
|
err = errors.New("unknown message type ID")
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return msg, nil
|
|
}
|
|
|
|
// readMessage reads one length-prefixed frame from conn and returns its
// content (everything after the length prefix).
//
// A frame is a little-endian uint32 length followed by that many bytes.
//
// Errors:
//   - io.EOF when the connection ends before any byte of a new frame arrives;
//   - io.ErrUnexpectedEOF when it ends in the middle of a frame;
//   - an error when the announced length exceeds maxMessageLen;
//   - any other error reported by conn.
func readMessage(conn net.Conn) ([]byte, error) {
	// The length comes from the peer. Without an upper bound a single
	// four-byte header could make the daemon allocate up to 4 GiB.
	const maxMessageLen = 16 << 20

	var lenBuf [4]byte
	// io.ReadFull returns a nil error only when the buffer was filled, so no
	// separate byte-count check is needed.
	if _, err := io.ReadFull(conn, lenBuf[:]); err != nil {
		return nil, err
	}

	msgLen := binary.LittleEndian.Uint32(lenBuf[:])
	if msgLen > maxMessageLen {
		return nil, errors.New("message length exceeds limit")
	}

	msgContent := make([]byte, msgLen)
	if _, err := io.ReadFull(conn, msgContent); err != nil {
		// io.ReadFull reports io.EOF when zero bytes of the body were read,
		// but the header was already consumed, so the frame is truncated.
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return nil, err
	}

	return msgContent, nil
}
|
|
|
|
func handleMessage(msg Message) {
|
|
switch msg.ID() {
|
|
case EchoRequestID:
|
|
handleEchoRequest(msg.(EchoRequest))
|
|
}
|
|
}
|
|
|
|
func handleEchoRequest(req EchoRequest) {
|
|
slog.Info("Got echo request", "EchoByte", req.EchoByte)
|
|
}
|
|
|
|
func writeMessage(conn net.Conn, msg Message) error {
|
|
msgBytes, err := encodeMessage(msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
n, err := conn.Write(msgBytes)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if n != len(msgBytes) {
|
|
return errors.New("could not write full message")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func encodeMessage(msg Message) ([]byte, error) {
|
|
var msgBytes []byte
|
|
var msgJsonBytes []byte
|
|
var err error
|
|
|
|
switch msg.ID() {
|
|
case EchoRequestID:
|
|
msgJsonBytes, err = json.Marshal(msg.(EchoRequest))
|
|
default:
|
|
err = errors.New("unknown message type")
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// +4 for type field length
|
|
msgLen := len(msgJsonBytes) + 4
|
|
msgLenBytes := make([]byte, 4)
|
|
binary.LittleEndian.PutUint32(msgLenBytes, uint32(msgLen))
|
|
msgBytes = append(msgBytes, msgLenBytes...)
|
|
|
|
msgTypeBytes := make([]byte, 4)
|
|
binary.LittleEndian.PutUint32(msgTypeBytes, msg.ID())
|
|
msgBytes = append(msgBytes, msgTypeBytes...)
|
|
|
|
msgBytes = append(msgBytes, msgJsonBytes...)
|
|
|
|
return msgBytes, nil
|
|
}
|
|
|
|
func writeSocketMessages(conn net.Conn, msgChan <-chan Message) error {
|
|
for msg := range msgChan {
|
|
err := writeMessage(conn, msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// unixSocketConnect dials the Unix domain socket at /tmp/liberum.sock and
// returns the connection, or a nil connection and the dial error.
func unixSocketConnect() (net.Conn, error) {
	return net.Dial("unix", "/tmp/liberum.sock")
}
|
|
|
|
func main() {
|
|
ctx := &daemon.Context{
|
|
PidFileName: "liberum-daemon.pid",
|
|
PidFilePerm: 0644,
|
|
LogFileName: "liberum-daemon.log",
|
|
LogFilePerm: 0640,
|
|
WorkDir: "./",
|
|
Umask: 027,
|
|
}
|
|
|
|
d, err := ctx.Reborn()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
if d != nil {
|
|
return
|
|
}
|
|
|
|
defer func(ctx *daemon.Context) {
|
|
slog.Info("Stopping daemon")
|
|
_ = ctx.Release()
|
|
}(ctx)
|
|
log.Println("Daemon started")
|
|
|
|
_ = os.Remove("/tmp/liberum.sock")
|
|
|
|
go func() {
|
|
err = unixSocketListen()
|
|
if err != nil {
|
|
slog.Error("Error listening unix socket", "error", err)
|
|
}
|
|
}()
|
|
|
|
time.Sleep(1 * time.Second)
|
|
conn, err := unixSocketConnect()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
msgWriteChan := make(chan Message, 64)
|
|
|
|
go func() {
|
|
err = writeSocketMessages(conn, msgWriteChan)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}()
|
|
|
|
msgWriteChan <- EchoRequest{123}
|
|
time.Sleep(time.Second * 5)
|
|
}
|