Commit a7629a8c authored by wei.xuan's avatar wei.xuan

feat: opt

parent bab9c2e0
...@@ -55,8 +55,7 @@ func (client *Client) register() { ...@@ -55,8 +55,7 @@ func (client *Client) register() {
packet := protocol.TypeRegister.CreatePacket() packet := protocol.TypeRegister.CreatePacket()
packet.Extra = client.Options.ClientID packet.Extra = client.Options.ClientID
packet.Data = protocol.EncodeExtra(client.Options.ExtraInfo) packet.Data = protocol.EncodeExtra(client.Options.ExtraInfo)
encode, _ := client.Codec.Encode(packet) if err := client.WriteAndFlush(packet); err != nil {
if err := client.WriteAndFlush(encode); err != nil {
logger.Error().Msgf("register to nat server with error %s", err.Error()) logger.Error().Msgf("register to nat server with error %s", err.Error())
} else { } else {
logger.Info().Msg("register to nat server success") logger.Info().Msg("register to nat server success")
...@@ -68,7 +67,6 @@ func (client *Client) handleNatEvent() { ...@@ -68,7 +67,6 @@ func (client *Client) handleNatEvent() {
reader := bufio.NewReader(client.natTunnel.Load().(net.Conn)) reader := bufio.NewReader(client.natTunnel.Load().(net.Conn))
majoraPacket, err := client.Codec.Decode(reader) majoraPacket, err := client.Codec.Decode(reader)
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
logger.Error().Msgf("*********disconnect******")
client.reConnect() client.reConnect()
continue continue
} }
...@@ -83,15 +81,15 @@ func (client *Client) handleNatEvent() { ...@@ -83,15 +81,15 @@ func (client *Client) handleNatEvent() {
case protocol.TypeHeartbeat: case protocol.TypeHeartbeat:
client.handleHeartbeatMessage() client.handleHeartbeatMessage()
case protocol.TypeConnect: case protocol.TypeConnect:
go client.handleConnect(majoraPacket) client.handleConnect(majoraPacket)
case protocol.TypeTransfer: case protocol.TypeTransfer:
go client.handleTransfer(majoraPacket) client.handleTransfer(majoraPacket)
case protocol.TypeDisconnect: case protocol.TypeDisconnect:
go client.handleDisconnectMessage(majoraPacket) client.handleDisconnectMessage(majoraPacket)
case protocol.TypeControl: case protocol.TypeControl:
go client.handleControlMessage(majoraPacket) client.handleControlMessage(majoraPacket)
case protocol.TypeDestroy: case protocol.TypeDestroy:
go client.handleDestroyMessage() client.handleDestroyMessage()
} }
} }
} }
......
...@@ -17,23 +17,23 @@ func (client *Client) handleHeartbeatMessage() { ...@@ -17,23 +17,23 @@ func (client *Client) handleHeartbeatMessage() {
go func() { go func() {
logger.Debug().Msg("receive heartbeat message from nat server") logger.Debug().Msg("receive heartbeat message from nat server")
packet := protocol.TypeHeartbeat.CreatePacket() packet := protocol.TypeHeartbeat.CreatePacket()
encode, _ := client.Codec.Encode(packet) if err := client.WriteAndFlush(packet); err != nil {
if err := client.WriteAndFlush(encode); err != nil {
logger.Error().Msgf("flush heart beat message error %s", err.Error()) logger.Error().Msgf("flush heart beat message error %s", err.Error())
} }
}() }()
} }
func (client *Client) handleConnect(packet *protocol.MajoraPacket) { func (client *Client) handleConnect(packet *protocol.MajoraPacket) {
go func(packet *protocol.MajoraPacket) {
if len(packet.Extra) == 0 { if len(packet.Extra) == 0 {
client.disconnect(packet, "empty extra") client.closeVirtualConnection(packet, "empty extra")
return return
} }
hostPort := strings.Split(packet.Extra, ":") hostPort := strings.Split(packet.Extra, ":")
if len(hostPort) != 2 { if len(hostPort) != 2 {
client.disconnect(packet, "invalid extra "+packet.Extra) client.closeVirtualConnection(packet, "invalid extra "+packet.Extra)
return return
} }
logger.Info().Msgf("handleConnect to %s", hostPort) logger.Info().Msgf("handleConnect to %s", hostPort)
...@@ -46,7 +46,7 @@ func (client *Client) handleConnect(packet *protocol.MajoraPacket) { ...@@ -46,7 +46,7 @@ func (client *Client) handleConnect(packet *protocol.MajoraPacket) {
addr := fmt.Sprintf("%s:%s", hostPort[0], hostPort[1]) addr := fmt.Sprintf("%s:%s", hostPort[0], hostPort[1])
conn, err = net.DialTimeout(common.TCP, addr, common.ConnTimeout) conn, err = net.DialTimeout(common.TCP, addr, common.ConnTimeout)
if err != nil { if err != nil {
client.disconnect(packet, "connect to target host error "+err.Error()) client.closeVirtualConnection(packet, "connect to target host error "+err.Error())
return return
} }
...@@ -55,18 +55,19 @@ func (client *Client) handleConnect(packet *protocol.MajoraPacket) { ...@@ -55,18 +55,19 @@ func (client *Client) handleConnect(packet *protocol.MajoraPacket) {
majoraPacket.SerialNumber = packet.SerialNumber majoraPacket.SerialNumber = packet.SerialNumber
majoraPacket.Extra = client.Options.ClientID majoraPacket.Extra = client.Options.ClientID
encode, _ := client.Codec.Encode(majoraPacket) if err := client.WriteAndFlush(majoraPacket); err != nil {
if err := client.WriteAndFlush(encode); err != nil {
logger.Error().Msgf("handleConnect message error %s", err.Error()) logger.Error().Msgf("handleConnect message error %s", err.Error())
_ = conn.Close() _ = conn.Close()
return return
} }
client.handleConnection(conn, packet) client.handleConnection(conn, packet)
}(packet)
} }
func (client *Client) WriteAndFlush(packet []byte) error { func (client *Client) WriteAndFlush(packet *protocol.MajoraPacket) error {
writer := bufio.NewWriter(client.natTunnel.Load().(net.Conn)) writer := bufio.NewWriter(client.natTunnel.Load().(net.Conn))
if _, err := writer.Write(packet); err != nil { encode := client.Codec.Encode(packet)
if _, err := writer.Write(encode); err != nil {
logger.Warn().Msgf("write to nat server error err:%+v", err) logger.Warn().Msgf("write to nat server error err:%+v", err)
return err return err
} }
...@@ -74,6 +75,7 @@ func (client *Client) WriteAndFlush(packet []byte) error { ...@@ -74,6 +75,7 @@ func (client *Client) WriteAndFlush(packet []byte) error {
} }
func (client *Client) handleTransfer(packet *protocol.MajoraPacket) { func (client *Client) handleTransfer(packet *protocol.MajoraPacket) {
go func(packet *protocol.MajoraPacket) {
conn, ok := client.GetConnection(packet, "handleTransfer") conn, ok := client.GetConnection(packet, "handleTransfer")
// 如何把这个错误告诉服务端 // 如何把这个错误告诉服务端
if !ok { if !ok {
...@@ -84,6 +86,7 @@ func (client *Client) handleTransfer(packet *protocol.MajoraPacket) { ...@@ -84,6 +86,7 @@ func (client *Client) handleTransfer(packet *protocol.MajoraPacket) {
logger.Warn().Msgf("write with error cnt=%d|err=%+v", cnt, err) logger.Warn().Msgf("write with error cnt=%d|err=%+v", cnt, err)
client.removeConnection(packet, "write_error") client.removeConnection(packet, "write_error")
} }
}(packet)
} }
func (client *Client) handleConnection(conn net.Conn, packet *protocol.MajoraPacket) { func (client *Client) handleConnection(conn net.Conn, packet *protocol.MajoraPacket) {
...@@ -109,32 +112,32 @@ func (client *Client) handleConnection(conn net.Conn, packet *protocol.MajoraPac ...@@ -109,32 +112,32 @@ func (client *Client) handleConnection(conn net.Conn, packet *protocol.MajoraPac
pack := protocol.TypeTransfer.CreatePacket() pack := protocol.TypeTransfer.CreatePacket()
pack.Data = buf pack.Data = buf
pack.SerialNumber = packet.SerialNumber pack.SerialNumber = packet.SerialNumber
encode, _ := client.Codec.Encode(pack) if err = client.WriteAndFlush(pack); err != nil {
if err = client.WriteAndFlush(encode); err != nil {
logger.Error().Msgf("write to nat server error %+v", err) logger.Error().Msgf("write to nat server error %+v", err)
} }
} }
} }
func (client *Client) handleDisconnectMessage(packet *protocol.MajoraPacket) { func (client *Client) handleDisconnectMessage(packet *protocol.MajoraPacket) {
go func() {
client.removeConnection(packet, "from_server") client.removeConnection(packet, "from_server")
}()
} }
func (client *Client) handleControlMessage(_ *protocol.MajoraPacket) { func (client *Client) handleControlMessage(_ *protocol.MajoraPacket) {
go func() {
logger.Debug().Msg("handleControlMessage ") logger.Debug().Msg("handleControlMessage ")
}()
} }
// handleDestroyMessage 是直接关闭nat server ?
func (client *Client) handleDestroyMessage() { func (client *Client) handleDestroyMessage() {
go func() {
} conn, ok := client.natTunnel.Load().(net.Conn)
if ok && conn != nil {
func (client *Client) disconnect(packet *protocol.MajoraPacket, msg string) { _ = conn.Close()
logger.Warn().Msgf("disconnect to server %s", msg) }
disconnectCmd := protocol.TypeDisconnect.CreatePacket() }()
disconnectCmd.SerialNumber = packet.SerialNumber
disconnectCmd.Data = []byte(msg)
encode, _ := client.Codec.Encode(disconnectCmd)
_ = client.WriteAndFlush(encode)
} }
func (client *Client) AddConnection(packet *protocol.MajoraPacket, conn net.Conn) { func (client *Client) AddConnection(packet *protocol.MajoraPacket, conn net.Conn) {
...@@ -160,8 +163,7 @@ func (client *Client) removeConnection(packet *protocol.MajoraPacket, reason str ...@@ -160,8 +163,7 @@ func (client *Client) removeConnection(packet *protocol.MajoraPacket, reason str
majoraPacket.SerialNumber = packet.SerialNumber majoraPacket.SerialNumber = packet.SerialNumber
majoraPacket.Data = []byte(client.Options.ClientID) majoraPacket.Data = []byte(client.Options.ClientID)
encode, _ := client.Codec.Encode(majoraPacket) if err := client.WriteAndFlush(majoraPacket); err != nil {
if err := client.WriteAndFlush(encode); err != nil {
logger.Warn().Msgf("flush to nat server error %s", err.Error()) logger.Warn().Msgf("flush to nat server error %s", err.Error())
} }
} }
...@@ -179,3 +181,14 @@ func (client *Client) GetConnection(packet *protocol.MajoraPacket, step string) ...@@ -179,3 +181,14 @@ func (client *Client) GetConnection(packet *protocol.MajoraPacket, step string)
conn, ok = load.(net.Conn) conn, ok = load.(net.Conn)
return return
} }
func (client *Client) closeVirtualConnection(packet *protocol.MajoraPacket, msg string) {
logger.Warn().Msgf("disconnect to server %s", msg)
majoraPacket := protocol.TypeDisconnect.CreatePacket()
majoraPacket.SerialNumber = packet.SerialNumber
majoraPacket.Extra = client.Options.ClientID
if err := client.WriteAndFlush(packet); err != nil {
logger.Error().Msgf("closeVirtualConnection with error %+v", err)
}
}
...@@ -13,8 +13,6 @@ const ( ...@@ -13,8 +13,6 @@ const (
TypeSize = 1 TypeSize = 1
ExtraSize = 1 ExtraSize = 1
SerialNumberSize = 8 SerialNumberSize = 8
KeyLenSize = 1
ValueLenSize = 1
MAGIC = int64(0x6D616A6F72613031) MAGIC = int64(0x6D616A6F72613031)
) )
...@@ -38,9 +36,9 @@ const ( ...@@ -38,9 +36,9 @@ const (
) )
var ( var (
NilPacketError = errors.New("packet is nil") ErrNilPacket = errors.New("packet is nil")
InvalidSizeError = errors.New("invalid size") ErrInvalidSize = errors.New("invalid size")
InvalidMagicError = errors.New("invalid magic") ErrInvalidMagic = errors.New("invalid magic")
) )
func ConvertInt32ToBytes(input int32) []byte { func ConvertInt32ToBytes(input int32) []byte {
......
...@@ -9,6 +9,11 @@ import ( ...@@ -9,6 +9,11 @@ import (
//export NewDefClient //export NewDefClient
func NewDefClient(asyn C.int, account *C.char) { func NewDefClient(asyn C.int, account *C.char) {
options := client.NewOptions() options := client.NewOptions()
newAccount := C.GoString(account)
if len(newAccount) > 0 {
options.Account = newAccount
}
cli := &client.Client{Options: options, Codec: protocol.NewDefCodec()} cli := &client.Client{Options: options, Codec: protocol.NewDefCodec()}
if int(asyn) > 0 { if int(asyn) > 0 {
...@@ -23,12 +28,18 @@ func NewClientWithNatServer(addr *C.char, clientID *C.char, asyn C.int, account ...@@ -23,12 +28,18 @@ func NewClientWithNatServer(addr *C.char, clientID *C.char, asyn C.int, account
options := client.NewOptions() options := client.NewOptions()
newAddr := C.GoString(addr) newAddr := C.GoString(addr)
newClientID := C.GoString(clientID) newClientID := C.GoString(clientID)
newAccount := C.GoString(account)
if len(newAddr) > 0 { if len(newAddr) > 0 {
options.NatHostPort = newAddr options.NatHostPort = newAddr
} }
if len(newClientID) > 0 { if len(newClientID) > 0 {
options.ClientID = newClientID options.ClientID = newClientID
} }
if len(newAccount) > 0 {
options.Account = newAccount
}
cli := &client.Client{Options: options, Codec: protocol.NewDefCodec()} cli := &client.Client{Options: options, Codec: protocol.NewDefCodec()}
if int(asyn) > 0 { if int(asyn) > 0 {
go cli.StartUp() go cli.StartUp()
......
...@@ -10,7 +10,7 @@ var ( ...@@ -10,7 +10,7 @@ var (
type ( type (
ICodec interface { ICodec interface {
Encode(packet *MajoraPacket) ([]byte, error) Encode(packet *MajoraPacket) []byte
Decode(reader *bufio.Reader) (*MajoraPacket, error) Decode(reader *bufio.Reader) (*MajoraPacket, error)
} }
...@@ -27,7 +27,7 @@ func NewDefCodec() *DefCodec { ...@@ -27,7 +27,7 @@ func NewDefCodec() *DefCodec {
} }
} }
func (d *DefCodec) Encode(packet *MajoraPacket) ([]byte, error) { func (d *DefCodec) Encode(packet *MajoraPacket) []byte {
return d.Encoder.Encode(packet) return d.Encoder.Encode(packet)
} }
......
...@@ -23,19 +23,19 @@ func (mpd *MajoraPacketDecoder) Decode(reader *bufio.Reader) (pack *MajoraPacket ...@@ -23,19 +23,19 @@ func (mpd *MajoraPacketDecoder) Decode(reader *bufio.Reader) (pack *MajoraPacket
} }
if !common.ReadMagic(magicbs) { if !common.ReadMagic(magicbs) {
return nil, common.InvalidMagicError return nil, common.ErrInvalidMagic
} }
frameLen, err := common.ReadInt32(reader) frameLen, err := common.ReadInt32(reader)
if err != nil { if err != nil {
return nil, common.InvalidSizeError return nil, common.ErrInvalidSize
} }
// type // type
msgType, err := common.ReadByte(reader) msgType, err := common.ReadByte(reader)
if err != nil { if err != nil {
logger.Error().Msgf("read type error %+v", err) logger.Error().Msgf("read type error %+v", err)
return nil, common.InvalidSizeError return nil, common.ErrInvalidSize
} }
pack = &MajoraPacket{} pack = &MajoraPacket{}
pack.Ttype = MajoraPacketType(msgType) pack.Ttype = MajoraPacketType(msgType)
...@@ -44,20 +44,20 @@ func (mpd *MajoraPacketDecoder) Decode(reader *bufio.Reader) (pack *MajoraPacket ...@@ -44,20 +44,20 @@ func (mpd *MajoraPacketDecoder) Decode(reader *bufio.Reader) (pack *MajoraPacket
pack.SerialNumber, err = common.ReadInt64(reader) pack.SerialNumber, err = common.ReadInt64(reader)
if err != nil { if err != nil {
logger.Error().Msgf("read type error %+v", err) logger.Error().Msgf("read type error %+v", err)
return nil, common.InvalidSizeError return nil, common.ErrInvalidSize
} }
// extra size // extra size
extraSize, err := common.ReadByte(reader) extraSize, err := common.ReadByte(reader)
if err != nil { if err != nil {
logger.Error().Msgf("read type error %+v", err) logger.Error().Msgf("read type error %+v", err)
return nil, common.InvalidSizeError return nil, common.ErrInvalidSize
} }
extra, err := common.ReadN(int(extraSize), reader) extra, err := common.ReadN(int(extraSize), reader)
if err != nil { if err != nil {
logger.Error().Msgf("read type error %+v", err) logger.Error().Msgf("read type error %+v", err)
return nil, common.InvalidSizeError return nil, common.ErrInvalidSize
} }
pack.Extra = string(extra) pack.Extra = string(extra)
...@@ -65,7 +65,7 @@ func (mpd *MajoraPacketDecoder) Decode(reader *bufio.Reader) (pack *MajoraPacket ...@@ -65,7 +65,7 @@ func (mpd *MajoraPacketDecoder) Decode(reader *bufio.Reader) (pack *MajoraPacket
dataSize := int(frameLen) - common.TypeSize - common.SerialNumberSize - common.ExtraSize - int(extraSize) dataSize := int(frameLen) - common.TypeSize - common.SerialNumberSize - common.ExtraSize - int(extraSize)
if dataSize < 0 { if dataSize < 0 {
logger.Error().Msgf("read type error %+v", err) logger.Error().Msgf("read type error %+v", err)
return nil, common.InvalidSizeError return nil, common.ErrInvalidSize
} }
if dataSize > 0 { if dataSize > 0 {
......
...@@ -7,15 +7,15 @@ import ( ...@@ -7,15 +7,15 @@ import (
) )
type Encoder interface { type Encoder interface {
Encode(*MajoraPacket) ([]byte, error) Encode(*MajoraPacket) []byte
} }
type MajoraPacketEncoder struct { type MajoraPacketEncoder struct {
} }
func (s *MajoraPacketEncoder) Encode(packet *MajoraPacket) ([]byte, error) { func (s *MajoraPacketEncoder) Encode(packet *MajoraPacket) []byte {
if packet == nil { if packet == nil {
return nil, common.NilPacketError return nil
} }
bodyLength := common.TypeSize + common.SerialNumberSize + common.ExtraSize bodyLength := common.TypeSize + common.SerialNumberSize + common.ExtraSize
...@@ -27,7 +27,6 @@ func (s *MajoraPacketEncoder) Encode(packet *MajoraPacket) ([]byte, error) { ...@@ -27,7 +27,6 @@ func (s *MajoraPacketEncoder) Encode(packet *MajoraPacket) ([]byte, error) {
innerBuf = make([]byte, 0, bodyLength+8+4) innerBuf = make([]byte, 0, bodyLength+8+4)
// todo 池化提高性能 // todo 池化提高性能
buffer = bytes.NewBuffer(innerBuf) buffer = bytes.NewBuffer(innerBuf)
err error
) )
// magic 8byte // magic 8byte
...@@ -51,5 +50,5 @@ func (s *MajoraPacketEncoder) Encode(packet *MajoraPacket) ([]byte, error) { ...@@ -51,5 +50,5 @@ func (s *MajoraPacketEncoder) Encode(packet *MajoraPacket) ([]byte, error) {
if len(packet.Data) > 0 { if len(packet.Data) > 0 {
buffer.Write(packet.Data) buffer.Write(packet.Data)
} }
return buffer.Bytes(), err return buffer.Bytes()
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment