Unverified Commit 30e62643 authored by Xin.Zh's avatar Xin.Zh Committed by GitHub

Merge pull request #58 from cvictory/fix/support_return_length

Ftr: return the total length and write length
parents 3eb8b38b 0e5e138f
...@@ -170,8 +170,10 @@ func TestTCPClient(t *testing.T) { ...@@ -170,8 +170,10 @@ func TestTCPClient(t *testing.T) {
func TestUDPClient(t *testing.T) { func TestUDPClient(t *testing.T) {
var ( var (
err error err error
conn *net.UDPConn conn *net.UDPConn
sendLen int
totalLen int
) )
func() { func() {
ip := net.ParseIP("127.0.0.1") ip := net.ParseIP("127.0.0.1")
...@@ -205,10 +207,14 @@ func TestUDPClient(t *testing.T) { ...@@ -205,10 +207,14 @@ func TestUDPClient(t *testing.T) {
assert.Equal(t, 1, msgHandler.SessionNumber()) assert.Equal(t, 1, msgHandler.SessionNumber())
ss := msgHandler.array[0] ss := msgHandler.array[0]
err = ss.WritePkg(nil, 0) totalLen, sendLen, err = ss.WritePkg(nil, 0)
assert.NotNil(t, err) assert.NotNil(t, err)
err = ss.WritePkg([]byte("hello"), 0) assert.True(t, sendLen == 0)
assert.True(t, totalLen == 0)
totalLen, sendLen, err = ss.WritePkg([]byte("hello"), 0)
assert.NotNil(t, perrors.Cause(err)) assert.NotNil(t, perrors.Cause(err))
assert.True(t, sendLen == 0)
assert.True(t, totalLen == 0)
l, err := ss.WriteBytes([]byte("hello")) l, err := ss.WriteBytes([]byte("hello"))
assert.Zero(t, l) assert.Zero(t, l)
assert.NotNil(t, err) assert.NotNil(t, err)
...@@ -240,9 +246,11 @@ func TestUDPClient(t *testing.T) { ...@@ -240,9 +246,11 @@ func TestUDPClient(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
beforeWritePkgNum := atomic.LoadUint32(&udpConn.writePkgNum) beforeWritePkgNum := atomic.LoadUint32(&udpConn.writePkgNum)
err = ss.WritePkg(udpCtx, 0) totalLen, sendLen, err = ss.WritePkg(udpCtx, 0)
assert.Equal(t, beforeWritePkgNum+1, atomic.LoadUint32(&udpConn.writePkgNum)) assert.Equal(t, beforeWritePkgNum+1, atomic.LoadUint32(&udpConn.writePkgNum))
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, sendLen == 0)
assert.True(t, totalLen == 0)
clt.Close() clt.Close()
assert.True(t, clt.IsClosed()) assert.True(t, clt.IsClosed())
......
...@@ -31,7 +31,7 @@ func ClientRequest() { ...@@ -31,7 +31,7 @@ func ClientRequest() {
go func() { go func() {
echoTimes := 10 echoTimes := 10
for i := 0; i < echoTimes; i++ { for i := 0; i < echoTimes; i++ {
err := ss.WritePkg("hello", WritePkgTimeout) _, _, err := ss.WritePkg("hello", WritePkgTimeout)
if err != nil { if err != nil {
log.Infof("session.WritePkg(session{%s}, error{%v}", ss.Stat(), err) log.Infof("session.WritePkg(session{%s}, error{%v}", ss.Stat(), err)
ss.Close() ss.Close()
......
...@@ -171,7 +171,10 @@ type Session interface { ...@@ -171,7 +171,10 @@ type Session interface {
// the Writer will invoke this function. Pls attention that if timeout is less than 0, WritePkg will send @pkg asap. // the Writer will invoke this function. Pls attention that if timeout is less than 0, WritePkg will send @pkg asap.
// for udp session, the first parameter should be UDPContext. // for udp session, the first parameter should be UDPContext.
WritePkg(pkg interface{}, timeout time.Duration) error // totalBytesLength: @pkg stream bytes length after encoding @pkg.
// sendBytesLength: stream bytes length that sent out successfully.
// err: maybe it has illegal data, encoding error, or write out system error.
WritePkg(pkg interface{}, timeout time.Duration) (totalBytesLength int, sendBytesLength int, err error)
WriteBytes([]byte) (int, error) WriteBytes([]byte) (int, error)
WriteBytesArray(...[]byte) (int, error) WriteBytesArray(...[]byte) (int, error)
Close() Close()
......
...@@ -347,12 +347,12 @@ func (s *session) sessionToken() string { ...@@ -347,12 +347,12 @@ func (s *session) sessionToken() string {
s.name, s.EndPoint().EndPointType(), s.ID(), s.LocalAddr(), s.RemoteAddr()) s.name, s.EndPoint().EndPointType(), s.ID(), s.LocalAddr(), s.RemoteAddr())
} }
func (s *session) WritePkg(pkg interface{}, timeout time.Duration) error { func (s *session) WritePkg(pkg interface{}, timeout time.Duration) (int, int, error) {
if pkg == nil { if pkg == nil {
return fmt.Errorf("@pkg is nil") return 0, 0, fmt.Errorf("@pkg is nil")
} }
if s.IsClosed() { if s.IsClosed() {
return ErrSessionClosed return 0, 0, ErrSessionClosed
} }
defer func() { defer func() {
...@@ -367,7 +367,7 @@ func (s *session) WritePkg(pkg interface{}, timeout time.Duration) error { ...@@ -367,7 +367,7 @@ func (s *session) WritePkg(pkg interface{}, timeout time.Duration) error {
pkgBytes, err := s.writer.Write(s, pkg) pkgBytes, err := s.writer.Write(s, pkg)
if err != nil { if err != nil {
log.Warnf("%s, [session.WritePkg] session.writer.Write(@pkg:%#v) = error:%+v", s.Stat(), pkg, err) log.Warnf("%s, [session.WritePkg] session.writer.Write(@pkg:%#v) = error:%+v", s.Stat(), pkg, err)
return perrors.WithStack(err) return len(pkgBytes), 0, perrors.WithStack(err)
} }
var udpCtxPtr *UDPContext var udpCtxPtr *UDPContext
if udpCtx, ok := pkg.(UDPContext); ok { if udpCtx, ok := pkg.(UDPContext); ok {
...@@ -384,13 +384,13 @@ func (s *session) WritePkg(pkg interface{}, timeout time.Duration) error { ...@@ -384,13 +384,13 @@ func (s *session) WritePkg(pkg interface{}, timeout time.Duration) error {
if 0 < timeout { if 0 < timeout {
s.Connection.SetWriteTimeout(timeout) s.Connection.SetWriteTimeout(timeout)
} }
_, err = s.Connection.send(pkg) var succssCount int
succssCount, err = s.Connection.send(pkg)
if err != nil { if err != nil {
log.Warnf("%s, [session.WritePkg] @s.Connection.Write(pkg:%#v) = err:%+v", s.Stat(), pkg, err) log.Warnf("%s, [session.WritePkg] @s.Connection.Write(pkg:%#v) = err:%+v", s.Stat(), pkg, err)
return perrors.WithStack(err) return len(pkgBytes), succssCount, perrors.WithStack(err)
} }
return len(pkgBytes), succssCount, nil
return nil
} }
// for codecs // for codecs
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment