Commit a7629a8c authored by wei.xuan's avatar wei.xuan

feat: opt

parent bab9c2e0
......@@ -55,8 +55,7 @@ func (client *Client) register() {
packet := protocol.TypeRegister.CreatePacket()
packet.Extra = client.Options.ClientID
packet.Data = protocol.EncodeExtra(client.Options.ExtraInfo)
encode, _ := client.Codec.Encode(packet)
if err := client.WriteAndFlush(encode); err != nil {
if err := client.WriteAndFlush(packet); err != nil {
logger.Error().Msgf("register to nat server with error %s", err.Error())
} else {
logger.Info().Msg("register to nat server success")
......@@ -68,7 +67,6 @@ func (client *Client) handleNatEvent() {
reader := bufio.NewReader(client.natTunnel.Load().(net.Conn))
majoraPacket, err := client.Codec.Decode(reader)
if errors.Is(err, io.EOF) {
logger.Error().Msgf("*********disconnect******")
client.reConnect()
continue
}
......@@ -83,15 +81,15 @@ func (client *Client) handleNatEvent() {
case protocol.TypeHeartbeat:
client.handleHeartbeatMessage()
case protocol.TypeConnect:
go client.handleConnect(majoraPacket)
client.handleConnect(majoraPacket)
case protocol.TypeTransfer:
go client.handleTransfer(majoraPacket)
client.handleTransfer(majoraPacket)
case protocol.TypeDisconnect:
go client.handleDisconnectMessage(majoraPacket)
client.handleDisconnectMessage(majoraPacket)
case protocol.TypeControl:
go client.handleControlMessage(majoraPacket)
client.handleControlMessage(majoraPacket)
case protocol.TypeDestroy:
go client.handleDestroyMessage()
client.handleDestroyMessage()
}
}
}
......
......@@ -17,23 +17,23 @@ func (client *Client) handleHeartbeatMessage() {
go func() {
logger.Debug().Msg("receive heartbeat message from nat server")
packet := protocol.TypeHeartbeat.CreatePacket()
encode, _ := client.Codec.Encode(packet)
if err := client.WriteAndFlush(encode); err != nil {
if err := client.WriteAndFlush(packet); err != nil {
logger.Error().Msgf("flush heart beat message error %s", err.Error())
}
}()
}
func (client *Client) handleConnect(packet *protocol.MajoraPacket) {
go func(packet *protocol.MajoraPacket) {
if len(packet.Extra) == 0 {
client.disconnect(packet, "empty extra")
client.closeVirtualConnection(packet, "empty extra")
return
}
hostPort := strings.Split(packet.Extra, ":")
if len(hostPort) != 2 {
client.disconnect(packet, "invalid extra "+packet.Extra)
client.closeVirtualConnection(packet, "invalid extra "+packet.Extra)
return
}
logger.Info().Msgf("handleConnect to %s", hostPort)
......@@ -46,7 +46,7 @@ func (client *Client) handleConnect(packet *protocol.MajoraPacket) {
addr := fmt.Sprintf("%s:%s", hostPort[0], hostPort[1])
conn, err = net.DialTimeout(common.TCP, addr, common.ConnTimeout)
if err != nil {
client.disconnect(packet, "connect to target host error "+err.Error())
client.closeVirtualConnection(packet, "connect to target host error "+err.Error())
return
}
......@@ -55,18 +55,19 @@ func (client *Client) handleConnect(packet *protocol.MajoraPacket) {
majoraPacket.SerialNumber = packet.SerialNumber
majoraPacket.Extra = client.Options.ClientID
encode, _ := client.Codec.Encode(majoraPacket)
if err := client.WriteAndFlush(encode); err != nil {
if err := client.WriteAndFlush(majoraPacket); err != nil {
logger.Error().Msgf("handleConnect message error %s", err.Error())
_ = conn.Close()
return
}
client.handleConnection(conn, packet)
}(packet)
}
func (client *Client) WriteAndFlush(packet []byte) error {
func (client *Client) WriteAndFlush(packet *protocol.MajoraPacket) error {
writer := bufio.NewWriter(client.natTunnel.Load().(net.Conn))
if _, err := writer.Write(packet); err != nil {
encode := client.Codec.Encode(packet)
if _, err := writer.Write(encode); err != nil {
logger.Warn().Msgf("write to nat server error err:%+v", err)
return err
}
......@@ -74,6 +75,7 @@ func (client *Client) WriteAndFlush(packet []byte) error {
}
func (client *Client) handleTransfer(packet *protocol.MajoraPacket) {
go func(packet *protocol.MajoraPacket) {
conn, ok := client.GetConnection(packet, "handleTransfer")
// 如何把这个错误告诉服务端
if !ok {
......@@ -84,6 +86,7 @@ func (client *Client) handleTransfer(packet *protocol.MajoraPacket) {
logger.Warn().Msgf("write with error cnt=%d|err=%+v", cnt, err)
client.removeConnection(packet, "write_error")
}
}(packet)
}
func (client *Client) handleConnection(conn net.Conn, packet *protocol.MajoraPacket) {
......@@ -109,32 +112,32 @@ func (client *Client) handleConnection(conn net.Conn, packet *protocol.MajoraPac
pack := protocol.TypeTransfer.CreatePacket()
pack.Data = buf
pack.SerialNumber = packet.SerialNumber
encode, _ := client.Codec.Encode(pack)
if err = client.WriteAndFlush(encode); err != nil {
if err = client.WriteAndFlush(pack); err != nil {
logger.Error().Msgf("write to nat server error %+v", err)
}
}
}
func (client *Client) handleDisconnectMessage(packet *protocol.MajoraPacket) {
go func() {
client.removeConnection(packet, "from_server")
}()
}
func (client *Client) handleControlMessage(_ *protocol.MajoraPacket) {
go func() {
logger.Debug().Msg("handleControlMessage ")
}()
}
// handleDestroyMessage 是直接关闭nat server ?
func (client *Client) handleDestroyMessage() {
}
func (client *Client) disconnect(packet *protocol.MajoraPacket, msg string) {
logger.Warn().Msgf("disconnect to server %s", msg)
disconnectCmd := protocol.TypeDisconnect.CreatePacket()
disconnectCmd.SerialNumber = packet.SerialNumber
disconnectCmd.Data = []byte(msg)
encode, _ := client.Codec.Encode(disconnectCmd)
_ = client.WriteAndFlush(encode)
go func() {
conn, ok := client.natTunnel.Load().(net.Conn)
if ok && conn != nil {
_ = conn.Close()
}
}()
}
func (client *Client) AddConnection(packet *protocol.MajoraPacket, conn net.Conn) {
......@@ -160,8 +163,7 @@ func (client *Client) removeConnection(packet *protocol.MajoraPacket, reason str
majoraPacket.SerialNumber = packet.SerialNumber
majoraPacket.Data = []byte(client.Options.ClientID)
encode, _ := client.Codec.Encode(majoraPacket)
if err := client.WriteAndFlush(encode); err != nil {
if err := client.WriteAndFlush(majoraPacket); err != nil {
logger.Warn().Msgf("flush to nat server error %s", err.Error())
}
}
......@@ -179,3 +181,14 @@ func (client *Client) GetConnection(packet *protocol.MajoraPacket, step string)
conn, ok = load.(net.Conn)
return
}
func (client *Client) closeVirtualConnection(packet *protocol.MajoraPacket, msg string) {
logger.Warn().Msgf("disconnect to server %s", msg)
majoraPacket := protocol.TypeDisconnect.CreatePacket()
majoraPacket.SerialNumber = packet.SerialNumber
majoraPacket.Extra = client.Options.ClientID
if err := client.WriteAndFlush(packet); err != nil {
logger.Error().Msgf("closeVirtualConnection with error %+v", err)
}
}
......@@ -13,8 +13,6 @@ const (
TypeSize = 1
ExtraSize = 1
SerialNumberSize = 8
KeyLenSize = 1
ValueLenSize = 1
MAGIC = int64(0x6D616A6F72613031)
)
......@@ -38,9 +36,9 @@ const (
)
var (
NilPacketError = errors.New("packet is nil")
InvalidSizeError = errors.New("invalid size")
InvalidMagicError = errors.New("invalid magic")
ErrNilPacket = errors.New("packet is nil")
ErrInvalidSize = errors.New("invalid size")
ErrInvalidMagic = errors.New("invalid magic")
)
func ConvertInt32ToBytes(input int32) []byte {
......
......@@ -9,6 +9,11 @@ import (
//export NewDefClient
func NewDefClient(asyn C.int, account *C.char) {
options := client.NewOptions()
newAccount := C.GoString(account)
if len(newAccount) > 0 {
options.Account = newAccount
}
cli := &client.Client{Options: options, Codec: protocol.NewDefCodec()}
if int(asyn) > 0 {
......@@ -23,12 +28,18 @@ func NewClientWithNatServer(addr *C.char, clientID *C.char, asyn C.int, account
options := client.NewOptions()
newAddr := C.GoString(addr)
newClientID := C.GoString(clientID)
newAccount := C.GoString(account)
if len(newAddr) > 0 {
options.NatHostPort = newAddr
}
if len(newClientID) > 0 {
options.ClientID = newClientID
}
if len(newAccount) > 0 {
options.Account = newAccount
}
cli := &client.Client{Options: options, Codec: protocol.NewDefCodec()}
if int(asyn) > 0 {
go cli.StartUp()
......
......@@ -10,7 +10,7 @@ var (
type (
ICodec interface {
Encode(packet *MajoraPacket) ([]byte, error)
Encode(packet *MajoraPacket) []byte
Decode(reader *bufio.Reader) (*MajoraPacket, error)
}
......@@ -27,7 +27,7 @@ func NewDefCodec() *DefCodec {
}
}
func (d *DefCodec) Encode(packet *MajoraPacket) ([]byte, error) {
func (d *DefCodec) Encode(packet *MajoraPacket) []byte {
return d.Encoder.Encode(packet)
}
......
......@@ -23,19 +23,19 @@ func (mpd *MajoraPacketDecoder) Decode(reader *bufio.Reader) (pack *MajoraPacket
}
if !common.ReadMagic(magicbs) {
return nil, common.InvalidMagicError
return nil, common.ErrInvalidMagic
}
frameLen, err := common.ReadInt32(reader)
if err != nil {
return nil, common.InvalidSizeError
return nil, common.ErrInvalidSize
}
// type
msgType, err := common.ReadByte(reader)
if err != nil {
logger.Error().Msgf("read type error %+v", err)
return nil, common.InvalidSizeError
return nil, common.ErrInvalidSize
}
pack = &MajoraPacket{}
pack.Ttype = MajoraPacketType(msgType)
......@@ -44,20 +44,20 @@ func (mpd *MajoraPacketDecoder) Decode(reader *bufio.Reader) (pack *MajoraPacket
pack.SerialNumber, err = common.ReadInt64(reader)
if err != nil {
logger.Error().Msgf("read type error %+v", err)
return nil, common.InvalidSizeError
return nil, common.ErrInvalidSize
}
// extra size
extraSize, err := common.ReadByte(reader)
if err != nil {
logger.Error().Msgf("read type error %+v", err)
return nil, common.InvalidSizeError
return nil, common.ErrInvalidSize
}
extra, err := common.ReadN(int(extraSize), reader)
if err != nil {
logger.Error().Msgf("read type error %+v", err)
return nil, common.InvalidSizeError
return nil, common.ErrInvalidSize
}
pack.Extra = string(extra)
......@@ -65,7 +65,7 @@ func (mpd *MajoraPacketDecoder) Decode(reader *bufio.Reader) (pack *MajoraPacket
dataSize := int(frameLen) - common.TypeSize - common.SerialNumberSize - common.ExtraSize - int(extraSize)
if dataSize < 0 {
logger.Error().Msgf("read type error %+v", err)
return nil, common.InvalidSizeError
return nil, common.ErrInvalidSize
}
if dataSize > 0 {
......
......@@ -7,15 +7,15 @@ import (
)
type Encoder interface {
Encode(*MajoraPacket) ([]byte, error)
Encode(*MajoraPacket) []byte
}
type MajoraPacketEncoder struct {
}
func (s *MajoraPacketEncoder) Encode(packet *MajoraPacket) ([]byte, error) {
func (s *MajoraPacketEncoder) Encode(packet *MajoraPacket) []byte {
if packet == nil {
return nil, common.NilPacketError
return nil
}
bodyLength := common.TypeSize + common.SerialNumberSize + common.ExtraSize
......@@ -27,7 +27,6 @@ func (s *MajoraPacketEncoder) Encode(packet *MajoraPacket) ([]byte, error) {
innerBuf = make([]byte, 0, bodyLength+8+4)
// todo 池化提高性能
buffer = bytes.NewBuffer(innerBuf)
err error
)
// magic 8byte
......@@ -51,5 +50,5 @@ func (s *MajoraPacketEncoder) Encode(packet *MajoraPacket) ([]byte, error) {
if len(packet.Data) > 0 {
buffer.Write(packet.Data)
}
return buffer.Bytes(), err
return buffer.Bytes()
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment