Unverified Commit 7886bc27 authored by 望哥's avatar 望哥 Committed by GitHub

Merge pull request #6 from divebomb/master

Add: task worker pool
parents 8b0e100a 58156632
......@@ -14,6 +14,10 @@
## develop history ##
---
- 2019/06/08
> Improvement
* add task worker pool
- 2019/06/07
> Improvement
* use time.After instead of wheel.After
......
......@@ -350,7 +350,6 @@ func (c *client) connect() {
}
err = c.newSession(ss)
if err == nil {
// ss.RunEventLoop()
ss.(*session).run()
c.Lock()
if c.ssMap == nil {
......@@ -421,6 +420,7 @@ func (c *client) stop() {
s.Close()
}
c.ssMap = nil
c.Unlock()
})
}
......
......@@ -13,7 +13,9 @@ import (
"compress/flate"
"net"
"time"
)
import (
perrors "github.com/pkg/errors"
)
......@@ -152,6 +154,7 @@ type Session interface {
SetRQLen(int)
SetWQLen(int)
SetWaitTime(time.Duration)
SetTaskPool(*TaskPool)
GetAttribute(interface{}) interface{}
SetAttribute(interface{}, interface{})
......
......@@ -10,6 +10,10 @@
package getty
import (
"fmt"
)
/////////////////////////////////////////
// Server Options
/////////////////////////////////////////
......@@ -105,3 +109,54 @@ func WithRootCertificateFile(cert string) ClientOption {
o.cert = cert
}
}
/////////////////////////////////////////
// Task Pool Options
/////////////////////////////////////////
// TaskPoolOptions holds the tunable parameters of a TaskPool.
// Zero or out-of-range fields are normalized by validate() before use;
// only tQPoolSize is mandatory (validate panics when it is < 1).
type TaskPoolOptions struct {
	tQLen      int // task queue length (buffer capacity of each queue)
	tQNumber   int // task queue number (how many queues workers share)
	tQPoolSize int // task pool size (number of worker goroutines)
}
// validate sanity-checks the options and fills in defaults.
// A pool size below 1 is a programming error and panics. Queue length and
// queue number fall back to package defaults when unset, and the queue
// number is capped at the pool size so every queue has at least one worker.
func (o *TaskPoolOptions) validate() {
	if o.tQPoolSize < 1 {
		panic(fmt.Sprintf("[getty][task_pool] illegal pool size %d", o.tQPoolSize))
	}

	if o.tQLen < 1 {
		o.tQLen = defaultTaskQLen
	}

	if o.tQNumber < 1 {
		o.tQNumber = defaultTaskQNumber
	}
	// The default may still exceed the pool size, so clamp afterwards.
	if o.tQPoolSize < o.tQNumber {
		o.tQNumber = o.tQPoolSize
	}
}
type TaskPoolOption func(*TaskPoolOptions)
// WithTaskPoolTaskPoolSize sets the task pool size, i.e. the number of
// worker goroutines the pool starts.
func WithTaskPoolTaskPoolSize(poolSize int) TaskPoolOption {
	return func(opts *TaskPoolOptions) {
		opts.tQPoolSize = poolSize
	}
}
// WithTaskPoolTaskQueueLength sets the buffer capacity of each task queue.
func WithTaskPoolTaskQueueLength(queueLen int) TaskPoolOption {
	return func(opts *TaskPoolOptions) {
		opts.tQLen = queueLen
	}
}
// WithTaskPoolTaskQueueNumber sets how many task queues the pool shards
// work across.
func WithTaskPoolTaskQueueNumber(queueNum int) TaskPoolOption {
	return func(opts *TaskPoolOptions) {
		opts.tQNumber = queueNum
	}
}
......@@ -252,7 +252,6 @@ func (s *server) runTcpEventLoop(newSession NewSessionCallback) {
continue
}
delay = 0
// client.RunEventLoop()
client.(*session).run()
}
}()
......@@ -323,7 +322,6 @@ func (s *wsHandler) serveWSRequest(w http.ResponseWriter, r *http.Request) {
if ss.(*session).maxMsgLen > 0 {
conn.SetReadLimit(int64(ss.(*session).maxMsgLen))
}
// ss.RunEventLoop()
ss.(*session).run()
}
......
......@@ -29,6 +29,7 @@ const (
netIOTimeout = 1e9 // 1s
period = 60 * 1e9 // 1 minute
pendingDuration = 3e9
defaultQLen = 1024
defaultSessionName = "session"
defaultTCPSessionName = "tcp-session"
defaultUDPSessionName = "udp-session"
......@@ -43,26 +44,38 @@ const (
// getty base session
type session struct {
name string
endPoint EndPoint
maxMsgLen int32
name string
endPoint EndPoint
// net read Write
Connection
// pkgHandler ReadWriter
reader Reader // @reader should be nil when @conn is a gettyWSConn object.
writer Writer
listener EventListener
once sync.Once
done chan struct{}
// errFlag bool
// codec
reader Reader // @reader should be nil when @conn is a gettyWSConn object.
writer Writer
// read & write
rQ chan interface{}
wQ chan interface{}
// handle logic
maxMsgLen int32
// task queue
tPool *TaskPool
// heartbeat
period time.Duration
wait time.Duration
rQ chan interface{}
wQ chan interface{}
// done
wait time.Duration
once sync.Once
done chan struct{}
// attribute
attrs *ValuesContext
// goroutines sync
grNum int32
lock sync.RWMutex
......@@ -70,14 +83,18 @@ type session struct {
func newSession(endPoint EndPoint, conn Connection) *session {
ss := &session{
name: defaultSessionName,
endPoint: endPoint,
maxMsgLen: maxReadBufLen,
name: defaultSessionName,
endPoint: endPoint,
Connection: conn,
done: make(chan struct{}),
period: period,
wait: pendingDuration,
attrs: NewValuesContext(nil),
maxMsgLen: maxReadBufLen,
period: period,
done: make(chan struct{}),
wait: pendingDuration,
attrs: NewValuesContext(nil),
}
ss.Connection.setSession(ss)
......@@ -115,7 +132,6 @@ func (s *session) Reset() {
s.name = defaultSessionName
s.once = sync.Once{}
s.done = make(chan struct{})
// s.errFlag = false
s.period = period
s.wait = pendingDuration
s.attrs = NewValuesContext(nil)
......@@ -190,30 +206,51 @@ func (s *session) IsClosed() bool {
}
// set maximum pacakge length of every pacakge in (EventListener)OnMessage(@pkgs)
func (s *session) SetMaxMsgLen(length int) { s.maxMsgLen = int32(length) }
// SetMaxMsgLen sets the maximum length of a single package handed to
// (EventListener)OnMessage. Guarded by s.lock.
func (s *session) SetMaxMsgLen(length int) {
	s.lock.Lock()
	s.maxMsgLen = int32(length)
	s.lock.Unlock()
}
// set session name
func (s *session) SetName(name string) { s.name = name }
// SetName sets the session's human-readable name. Guarded by s.lock.
func (s *session) SetName(name string) {
	s.lock.Lock()
	s.name = name
	s.lock.Unlock()
}
// SetEventListener installs the EventListener that receives this
// session's callbacks. Guarded by s.lock.
func (s *session) SetEventListener(listener EventListener) {
	s.lock.Lock()
	s.listener = listener
	s.lock.Unlock()
}
// SetPkgHandler installs @handler as both the session's package Reader
// and Writer. Guarded by s.lock.
func (s *session) SetPkgHandler(handler ReadWriter) {
	s.lock.Lock()
	s.reader, s.writer = handler, handler
	s.lock.Unlock()
}
// SetReader sets the session's package Reader. Guarded by s.lock.
func (s *session) SetReader(reader Reader) {
	s.lock.Lock()
	s.reader = reader
	s.lock.Unlock()
}
// SetWriter sets the session's package Writer. Guarded by s.lock.
func (s *session) SetWriter(writer Writer) {
	s.lock.Lock()
	s.writer = writer
	s.lock.Unlock()
}
......@@ -224,8 +261,8 @@ func (s *session) SetCronPeriod(period int) {
}
s.lock.Lock()
defer s.lock.Unlock()
s.period = time.Duration(period) * time.Millisecond
s.lock.Unlock()
}
// set @session's read queue size
......@@ -235,9 +272,9 @@ func (s *session) SetRQLen(readQLen int) {
}
s.lock.Lock()
defer s.lock.Unlock()
s.rQ = make(chan interface{}, readQLen)
s.lock.Unlock()
log.Debugf("%s, [session.SetRQLen] rQ{len:%d, cap:%d}", s.Stat(), len(s.rQ), cap(s.rQ))
log.Debug("%s, [session.SetRQLen] rQ{len:%d, cap:%d}", s.Stat(), len(s.rQ), cap(s.rQ))
}
// set @session's Write queue size
......@@ -247,9 +284,9 @@ func (s *session) SetWQLen(writeQLen int) {
}
s.lock.Lock()
defer s.lock.Unlock()
s.wQ = make(chan interface{}, writeQLen)
s.lock.Unlock()
log.Debugf("%s, [session.SetWQLen] wQ{len:%d, cap:%d}", s.Stat(), len(s.wQ), cap(s.wQ))
log.Debug("%s, [session.SetWQLen] wQ{len:%d, cap:%d}", s.Stat(), len(s.wQ), cap(s.wQ))
}
// set maximum wait time when session got error or got exit signal
......@@ -259,8 +296,16 @@ func (s *session) SetWaitTime(waitTime time.Duration) {
}
s.lock.Lock()
defer s.lock.Unlock()
s.wait = waitTime
s.lock.Unlock()
}
// SetTaskPool attaches a task pool to the session; when set, inbound
// packages are dispatched to the pool (see addTask) instead of the
// session read queue. Guarded by s.lock.
func (s *session) SetTaskPool(p *TaskPool) {
	s.lock.Lock()
	s.tPool = p
	s.lock.Unlock()
}
// set attribute of key @session:key
......@@ -400,12 +445,6 @@ func (s *session) WriteBytesArray(pkgs ...[]byte) error {
// func (s *session) RunEventLoop() {
func (s *session) run() {
if s.rQ == nil || s.wQ == nil {
errStr := fmt.Sprintf("session{name:%s, rQ:%#v, wQ:%#v}",
s.name, s.rQ, s.wQ)
log.Error(errStr)
panic(errStr)
}
if s.Connection == nil || s.listener == nil || s.writer == nil {
errStr := fmt.Sprintf("session{name:%s, conn:%#v, listener:%#v, writer:%#v}",
s.name, s.Connection, s.listener, s.writer)
......@@ -413,6 +452,14 @@ func (s *session) run() {
panic(errStr)
}
if s.wQ == nil {
s.wQ = make(chan interface{}, defaultQLen)
}
if s.rQ == nil && s.tPool == nil {
s.rQ = make(chan interface{}, defaultQLen)
}
// call session opened
s.UpdateActive()
if err := s.listener.OnOpen(s); err != nil {
......@@ -420,6 +467,7 @@ func (s *session) run() {
return
}
// start read/write gr
atomic.AddInt32(&(s.grNum), 2)
go s.handleLoop()
go s.handlePackage()
......@@ -448,10 +496,8 @@ func (s *session) handleLoop() {
}
grNum = atomic.AddInt32(&(s.grNum), -1)
// if !s.errFlag {
s.listener.OnClose(s)
// }
log.Infof("%s, [session.handleLoop] goroutine exit now, left gr num %d", s.Stat(), grNum)
log.Info("%s, [session.handleLoop] goroutine exit now, left gr num %d", s.Stat(), grNum)
s.gc()
}()
......@@ -516,6 +562,14 @@ LOOP:
}
}
// addTask dispatches an inbound package: to the task pool when one is
// configured, otherwise onto the session read queue for handleLoop.
func (s *session) addTask(pkg interface{}) {
	if s.tPool == nil {
		s.rQ <- pkg
		return
	}
	s.tPool.AddTask(task{session: s, pkg: pkg})
}
func (s *session) handlePackage() {
var (
err error
......@@ -593,8 +647,6 @@ func (s *session) handleTCPPackage() error {
break
}
log.Errorf("%s, [session.conn.read] = error:%+v", s.sessionToken(), err)
// for (Codec)OnErr
// s.errFlag = true
exit = true
}
break
......@@ -619,8 +671,6 @@ func (s *session) handleTCPPackage() error {
if err != nil {
log.Warnf("%s, [session.handleTCPPackage] = len{%d}, error:%+v",
s.sessionToken(), pkgLen, err)
// for (Codec)OnErr
// s.errFlag = true
exit = true
break
}
......@@ -630,7 +680,7 @@ func (s *session) handleTCPPackage() error {
}
// handle case 3
s.UpdateActive()
s.rQ <- pkg
s.addTask(pkg)
pktBuf.Next(pkgLen)
// continue to handle case 4
}
......@@ -705,7 +755,7 @@ func (s *session) handleUDPPackage() error {
}
s.UpdateActive()
s.rQ <- UDPContext{Pkg: pkg, PeerAddr: addr}
s.addTask(UDPContext{Pkg: pkg, PeerAddr: addr})
}
return perrors.WithStack(err)
......@@ -735,7 +785,6 @@ func (s *session) handleWSPackage() error {
if err != nil {
log.Warnf("%s, [session.handleWSPackage] = error{%+s}",
s.sessionToken(), err)
// s.errFlag = true
return perrors.WithStack(err)
}
s.UpdateActive()
......@@ -749,9 +798,10 @@ func (s *session) handleWSPackage() error {
s.sessionToken(), length, err)
continue
}
s.rQ <- unmarshalPkg
s.addTask(unmarshalPkg)
} else {
s.rQ <- pkg
s.addTask(pkg)
}
}
......@@ -784,10 +834,14 @@ func (s *session) gc() {
s.lock.Lock()
if s.attrs != nil {
s.attrs = nil
close(s.wQ)
s.wQ = nil
close(s.rQ)
s.rQ = nil
if s.wQ != nil {
close(s.wQ)
s.wQ = nil
}
if s.rQ != nil {
close(s.rQ)
s.rQ = nil
}
s.Connection.close((int)((int64)(s.wait)))
}
s.lock.Unlock()
......@@ -797,5 +851,6 @@ func (s *session) gc() {
// or (session)handleLoop automatically. It's thread safe.
func (s *session) Close() {
s.stop()
log.Infof("%s closed now. its current gr num is %d", s.sessionToken(), atomic.LoadInt32(&(s.grNum)))
log.Info("%s closed now. its current gr num is %d",
s.sessionToken(), atomic.LoadInt32(&(s.grNum)))
}
package getty
import (
"sync"
"sync/atomic"
)
const (
defaultTaskQNumber = 10
defaultTaskQLen = 128
)
// task couples an inbound package with the session it arrived on, so a
// pool worker can invoke the session listener's OnMessage.
type task struct {
	session *session
	pkg     interface{}
}
// TaskPool fans inbound tasks out to a fixed set of worker goroutines.
// Tasks are sharded round-robin over qArray; workers are bound to queues
// round-robin in start(). Lifetime: NewTaskPool -> start -> Close.
type TaskPool struct {
	TaskPoolOptions

	idx    uint32      // round robin index (advanced atomically in AddTask)
	qArray []chan task // tQNumber buffered task queues
	wg     sync.WaitGroup // tracks worker goroutines for Close
	once   sync.Once      // ensures done is closed exactly once
	done   chan struct{}  // closed by stop() to signal shutdown
}
// NewTaskPool builds a task pool from the given functional options.
// An illegal pool size panics via TaskPoolOptions.validate. Note that the
// worker goroutines are not launched here; they are started separately
// via start().
func NewTaskPool(opts ...TaskPoolOption) *TaskPool {
	var options TaskPoolOptions
	for _, apply := range opts {
		apply(&options)
	}
	options.validate()

	pool := &TaskPool{
		TaskPoolOptions: options,
		done:            make(chan struct{}),
		qArray:          make([]chan task, options.tQNumber),
	}
	for i := range pool.qArray {
		pool.qArray[i] = make(chan task, pool.tQLen)
	}

	return pool
}
// start launches tQPoolSize worker goroutines. Workers are assigned to
// the tQNumber task queues round-robin, so every queue is consumed by at
// least one worker (validate guarantees tQNumber <= tQPoolSize).
func (p *TaskPool) start() {
	for i := 0; i < p.tQPoolSize; i++ {
		p.wg.Add(1)
		// i is already an int; the original's int(workerID) conversion and
		// intermediate locals were redundant — pass the id and queue directly.
		go p.run(i, p.qArray[i%p.tQNumber])
	}
}
// run is the body of one pool worker: it drains tasks from @q and feeds
// each package to its session's EventListener until the pool's done
// channel is closed. When done fires while tasks remain buffered, those
// tasks are abandoned and a warning is logged.
func (p *TaskPool) run(id int, q chan task) {
	defer p.wg.Done()

	for {
		select {
		case <-p.done:
			if len(q) > 0 {
				log.Warn("[getty][task_pool] task worker %d exit now while its task buffer length %d is greater than 0",
					id, len(q))
			} else {
				log.Info("[getty][task_pool] task worker %d exit now", id)
			}
			return

		case t, ok := <-q:
			if ok {
				t.session.listener.OnMessage(t.session, t.pkg)
			}
		}
	}
}
// AddTask enqueues @t on one of the task queues, chosen round-robin via
// an atomic counter. It blocks while the chosen queue is full, and
// returns immediately (dropping the task) once the pool is stopped.
func (p *TaskPool) AddTask(t task) {
	qIdx := atomic.AddUint32(&p.idx, 1) % uint32(p.tQNumber)
	select {
	case p.qArray[qIdx] <- t:
	case <-p.done:
	}
}
// stop signals all workers to exit by closing the done channel. It is
// safe to call repeatedly and concurrently: sync.Once already guarantees
// the channel is closed exactly once and is race-free, so the original's
// select/default pre-check of done was redundant and has been removed.
func (p *TaskPool) stop() {
	p.once.Do(func() {
		close(p.done)
	})
}
// IsClosed reports whether the pool has been stopped (its done channel
// closed).
func (p *TaskPool) IsClosed() bool {
	select {
	case <-p.done:
		return true
	default:
	}
	return false
}
// Close stops the pool, waits for every worker goroutine to exit, then
// closes all task queues.
// NOTE(review): an AddTask running concurrently with Close could in
// principle race its queue send against close(q) here — confirm callers
// stop producing before calling Close.
func (p *TaskPool) Close() {
	p.stop()
	p.wg.Wait()
	for _, q := range p.qArray {
		close(q)
	}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment