1
0

refactor(daemon): switch to gob encoding

This commit is contained in:
2024-02-28 03:04:38 +00:00
parent 1c8e1d068b
commit fc193afe92
3 changed files with 54 additions and 118 deletions

View File

@ -1,7 +1,8 @@
package api
import (
"encoding/json"
"bytes"
"encoding/gob"
"errors"
"fmt"
"net"
@ -11,9 +12,7 @@ import (
"code.dumpstack.io/tools/out-of-tree/artifact"
"code.dumpstack.io/tools/out-of-tree/distro"
"github.com/davecgh/go-spew/spew"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
var ErrInvalid = errors.New("")
@ -106,9 +105,12 @@ type Req struct {
Data []byte
}
func (r *Req) SetData(data any) {
func (r *Req) SetData(data any) (err error) {
r.Type = fmt.Sprintf("%v", reflect.TypeOf(data))
r.Data = Marshal(data)
var buf bytes.Buffer
err = gob.NewEncoder(&buf).Encode(data)
r.Data = buf.Bytes()
return
}
func (r *Req) GetData(data any) (err error) {
@ -122,32 +124,16 @@ func (r *Req) GetData(data any) (err error) {
return
}
log.Trace().Msgf("unmarshal %v", string(r.Data))
err = json.Unmarshal(r.Data, &data)
return
buf := bytes.NewBuffer(r.Data)
return gob.NewDecoder(buf).Decode(data)
}
func (r Req) Encode(conn net.Conn) {
log.Debug().Msgf("encode %v", r.Command)
err := json.NewEncoder(conn).Encode(&r)
if err != nil {
log.Fatal().Msgf("encode %v", r.Command)
}
func (r *Req) Encode(conn net.Conn) (err error) {
return gob.NewEncoder(conn).Encode(r)
}
func (r *Req) Decode(conn net.Conn) (err error) {
err = json.NewDecoder(conn).Decode(r)
return
}
func (r Req) Marshal() (bytes []byte) {
return Marshal(r)
}
func (Req) Unmarshal(data []byte) (r Req, err error) {
err = json.Unmarshal(data, &r)
log.Trace().Msgf("unmarshal %v", spew.Sdump(r))
return
return gob.NewDecoder(conn).Decode(r)
}
type Resp struct {
@ -166,9 +152,12 @@ func NewResp() (resp Resp) {
return
}
func (r *Resp) SetData(data any) {
func (r *Resp) SetData(data any) (err error) {
r.Type = fmt.Sprintf("%v", reflect.TypeOf(data))
r.Data = Marshal(data)
var buf bytes.Buffer
err = gob.NewEncoder(&buf).Encode(data)
r.Data = buf.Bytes()
return
}
func (r *Resp) GetData(data any) (err error) {
@ -182,48 +171,19 @@ func (r *Resp) GetData(data any) (err error) {
return
}
log.Trace().Msgf("unmarshal %v", string(r.Data))
err = json.Unmarshal(r.Data, &data)
return
buf := bytes.NewBuffer(r.Data)
return gob.NewDecoder(buf).Decode(data)
}
func (r *Resp) Encode(conn net.Conn) {
func (r *Resp) Encode(conn net.Conn) (err error) {
if r.Err != nil && r.Err != ErrInvalid && r.Error == "" {
r.Error = fmt.Sprintf("%v", r.Err)
}
log.Debug().Msgf("encode %v", r.UUID)
err := json.NewEncoder(conn).Encode(r)
if err != nil {
log.Fatal().Msgf("encode %v", r.UUID)
}
return gob.NewEncoder(conn).Encode(r)
}
func (r *Resp) Decode(conn net.Conn) (err error) {
err = json.NewDecoder(conn).Decode(r)
err = gob.NewDecoder(conn).Decode(r)
r.Err = ErrInvalid
return
}
func (r *Resp) Marshal() (bytes []byte) {
if r.Err != nil && r.Err != ErrInvalid && r.Error == "" {
r.Error = fmt.Sprintf("%v", r.Err)
}
return Marshal(r)
}
func (Resp) Unmarshal(data []byte) (r Resp, err error) {
err = json.Unmarshal(data, &r)
log.Trace().Msgf("unmarshal %v", spew.Sdump(r))
r.Err = ErrInvalid
return
}
func Marshal(data any) (bytes []byte) {
bytes, err := json.Marshal(data)
if err != nil {
log.Fatal().Err(err).Msgf("marshal %v", data)
}
log.Trace().Msgf("marshal %v", string(bytes))
return
}

View File

@ -1,49 +0,0 @@
package api
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestReq(t *testing.T) {
req := Req{}
req.Command = ListRepos
req.SetData(&Job{ID: 999, RepoName: "test"})
bytes := req.Marshal()
req2, err := Req{}.Unmarshal(bytes)
assert.Nil(t, err)
assert.Equal(t, req, req2)
job := Job{}
err = req2.GetData(&job)
assert.Nil(t, err)
assert.Equal(t, req2.Type, "*api.Job")
}
func TestResp(t *testing.T) {
resp := Resp{}
resp.Error = "abracadabra"
resp.SetData(&[]Repo{{}, {}})
bytes := resp.Marshal()
resp2, err := Resp{}.Unmarshal(bytes)
assert.Nil(t, err)
resp2.Err = nil // non-marshallable
assert.Equal(t, resp, resp2)
var repos []Repo
err = resp2.GetData(&repos)
assert.Nil(t, err)
assert.Equal(t, resp2.Type, "*[]api.Repo")
}