Commit 1e0ac263 authored by AlexStocks's avatar AlexStocks

Imp: just treat task pool as a independent role

parent ac39a3ba
......@@ -51,9 +51,6 @@ type client struct {
newSession NewSessionCallback
ssMap map[Session]struct{}
// task queue pool
tQPool *taskPool
sync.Once
done chan struct{}
wg sync.WaitGroup
......@@ -79,10 +76,6 @@ func newClient(t EndPointType, opts ...ClientOption) *client {
c.ssMap = make(map[Session]struct{}, c.number)
if c.tQPoolSize > 0 {
c.tQPool = newTaskPool(c.taskPoolOptions)
}
return c
}
......@@ -357,7 +350,6 @@ func (c *client) connect() {
}
err = c.newSession(ss)
if err == nil {
ss.(*session).SetTaskPool(c.tQPool)
ss.(*session).run()
c.Lock()
if c.ssMap == nil {
......@@ -429,10 +421,6 @@ func (c *client) stop() {
}
c.ssMap = nil
if c.tQPool != nil {
c.tQPool.close()
c.tQPool = nil
}
c.Unlock()
})
}
......
......@@ -10,6 +10,10 @@
package getty
import (
"fmt"
)
/////////////////////////////////////////
// Server Options
/////////////////////////////////////////
......@@ -19,9 +23,6 @@ type ServerOption func(*ServerOptions)
type ServerOptions struct {
addr string
// task pool
taskPoolOptions
// websocket
path string
cert string
......@@ -36,27 +37,6 @@ func WithLocalAddress(addr string) ServerOption {
}
}
// @size is the task queue pool size
func WithServerTaskPoolSize(size int32) ServerOption {
return func(o *ServerOptions) {
o.taskPoolOptions.tQPoolSize = size
}
}
// @length is the task queue length
func WithServerTaskQueueLength(length int32) ServerOption {
return func(o *ServerOptions) {
o.taskPoolOptions.tQLen = length
}
}
// @number is the task queue number
func WithServerTaskQueueNumber(number int32) ServerOption {
return func(o *ServerOptions) {
o.taskPoolOptions.tQNumber = number
}
}
// @path: websocket request url path
func WithWebsocketServerPath(path string) ServerOption {
return func(o *ServerOptions) {
......@@ -96,9 +76,6 @@ type ClientOptions struct {
number int
reconnectInterval int // reConnect Interval
// task pool
taskPoolOptions
// the cert file of wss server which may contain server domain, server ip, the starting effective date, effective
// duration, the hash alg, the len of the private key.
// wss client will use it.
......@@ -126,30 +103,60 @@ func WithConnectionNumber(num int) ClientOption {
}
}
// @size is the task queue pool size
func WithClientTaskPoolSize(size int32) ClientOption {
// @cert is client certificate file. it can be empty.
func WithRootCertificateFile(cert string) ClientOption {
return func(o *ClientOptions) {
o.taskPoolOptions.tQPoolSize = size
o.cert = cert
}
}
// @length is the task queue length
func WithClientTaskQueueLength(length int32) ClientOption {
return func(o *ClientOptions) {
o.taskPoolOptions.tQLen = length
/////////////////////////////////////////
// Task Pool Options
/////////////////////////////////////////
type TaskPoolOptions struct {
tQLen int // task queue length
tQNumber int // task queue number
tQPoolSize int // task pool size
}
func (o *TaskPoolOptions) validate() {
if o.tQPoolSize == 0 {
panic(fmt.Sprintf("[getty][task_pool] illegal pool size %d", o.tQPoolSize))
}
if o.tQLen == 0 {
o.tQLen = defaultTaskQLen
}
if o.tQNumber < 1 {
o.tQNumber = defaultTaskQNumber
}
if o.tQNumber > o.tQPoolSize {
o.tQNumber = o.tQPoolSize
}
}
// @number is the task queue number
func WithClientTaskQueueNumber(number int32) ClientOption {
return func(o *ClientOptions) {
o.taskPoolOptions.tQNumber = number
type TaskPoolOption func(*TaskPoolOptions)
// @size is the task queue pool size
func WithTaskPoolTaskPoolSize(size int) TaskPoolOption {
return func(o *TaskPoolOptions) {
o.tQPoolSize = size
}
}
// @cert is client certificate file. it can be empty.
func WithRootCertificateFile(cert string) ClientOption {
return func(o *ClientOptions) {
o.cert = cert
// @length is the task queue length
func WithTaskPoolTaskQueueLength(length int) TaskPoolOption {
return func(o *TaskPoolOptions) {
o.tQLen = length
}
}
// @number is the task queue number
func WithTaskPoolTaskQueueNumber(number int) TaskPoolOption {
return func(o *TaskPoolOptions) {
o.tQNumber = number
}
}
......@@ -41,9 +41,6 @@ type server struct {
endPointType EndPointType
server *http.Server // for ws or wss server
// task queue pool
tQPool *taskPool
sync.Once
done chan struct{}
wg sync.WaitGroup
......@@ -67,10 +64,6 @@ func newServer(t EndPointType, opts ...ServerOption) *server {
panic(fmt.Sprintf("@addr:%s", s.addr))
}
if s.tQPoolSize > 0 {
s.tQPool = newTaskPool(s.taskPoolOptions)
}
return s
}
......@@ -137,10 +130,6 @@ func (s *server) stop() {
s.pktListener.Close()
s.pktListener = nil
}
if s.tQPool != nil {
s.tQPool.close()
s.tQPool = nil
}
})
}
}
......@@ -263,7 +252,6 @@ func (s *server) runTcpEventLoop(newSession NewSessionCallback) {
continue
}
delay = 0
client.(*session).SetTaskPool(s.tQPool)
client.(*session).run()
}
}()
......@@ -278,7 +266,6 @@ func (s *server) runUDPEventLoop(newSession NewSessionCallback) {
if err := newSession(ss); err != nil {
panic(err.Error())
}
ss.(*session).SetTaskPool(s.tQPool)
ss.(*session).run()
}
......@@ -335,7 +322,6 @@ func (s *wsHandler) serveWSRequest(w http.ResponseWriter, r *http.Request) {
if ss.(*session).maxMsgLen > 0 {
conn.SetReadLimit(int64(ss.(*session).maxMsgLen))
}
ss.(*session).SetTaskPool(s.server.tQPool)
ss.(*session).run()
}
......
......@@ -63,7 +63,7 @@ type session struct {
// handle logic
maxMsgLen int32
// task queue
tQPool *taskPool
tPool *TaskPool
// heartbeat
period time.Duration
......@@ -89,7 +89,6 @@ func newSession(endPoint EndPoint, conn Connection) *session {
Connection: conn,
maxMsgLen: maxReadBufLen,
tQLen: defaultTaskQLen,
period: period,
......@@ -302,11 +301,11 @@ func (s *session) SetWaitTime(waitTime time.Duration) {
}
// set task pool
func (s *session) SetTaskPool(p *taskPool) {
func (s *session) SetTaskPool(p *TaskPool) {
s.lock.Lock()
defer s.lock.Unlock()
s.tQPool = p
s.tPool = p
}
// set attribute of key @session:key
......@@ -457,7 +456,7 @@ func (s *session) run() {
s.wQ = make(chan interface{}, defaultQLen)
}
if s.rQ == nil && s.tQPool == nil {
if s.rQ == nil && s.tPool == nil {
s.rQ = make(chan interface{}, defaultQLen)
}
......@@ -564,8 +563,8 @@ LOOP:
}
func (s *session) addTask(pkg interface{}) {
if s.tQPool != nil {
s.tQPool.AddTask(task{session: s, pkg: pkg})
if s.tPool != nil {
s.tPool.AddTask(task{session: s, pkg: pkg})
} else {
s.rQ <- pkg
}
......
package getty
import (
"fmt"
"sync"
"sync/atomic"
)
......@@ -17,33 +16,9 @@ type task struct {
pkg interface{}
}
type taskPoolOptions struct {
tQLen int32 // task queue length
tQNumber int32 // task queue number
tQPoolSize int32 // task pool size
}
func (o *taskPoolOptions) Validate() {
if o.tQPoolSize == 0 {
panic(fmt.Sprintf("[getty][task_pool] illegal pool size %d", o.tQPoolSize))
}
if o.tQLen == 0 {
o.tQLen = defaultTaskQLen
}
if o.tQNumber < 1 {
o.tQNumber = defaultTaskQNumber
}
if o.tQNumber > o.tQPoolSize {
o.tQNumber = o.tQPoolSize
}
}
// task pool: manage task ts
type taskPool struct {
taskPoolOptions
type TaskPool struct {
TaskPoolOptions
idx uint32 // round robin index
qArray []chan task
......@@ -54,16 +29,16 @@ type taskPool struct {
}
// build a task pool
func newTaskPool(opts taskPoolOptions) *taskPool {
opts.Validate()
func newTaskPool(opts TaskPoolOptions) *TaskPool {
opts.validate()
p := &taskPool{
taskPoolOptions: opts,
p := &TaskPool{
TaskPoolOptions: opts,
qArray: make([]chan task, opts.tQNumber),
done: make(chan struct{}),
}
for i := int32(0); i < p.tQNumber; i++ {
for i := 0; i < p.tQNumber; i++ {
p.qArray[i] = make(chan task, p.tQLen)
}
......@@ -71,8 +46,8 @@ func newTaskPool(opts taskPoolOptions) *taskPool {
}
// start task pool
func (p *taskPool) start() {
for i := int32(0); i < p.tQPoolSize; i++ {
func (p *TaskPool) start() {
for i := 0; i < p.tQPoolSize; i++ {
p.wg.Add(1)
workerID := i
q := p.qArray[workerID%p.tQNumber]
......@@ -81,7 +56,7 @@ func (p *taskPool) start() {
}
// worker
func (p *taskPool) run(id int, q chan task) {
func (p *TaskPool) run(id int, q chan task) {
defer p.wg.Done()
var (
......@@ -109,7 +84,7 @@ func (p *taskPool) run(id int, q chan task) {
}
// add task
func (p *taskPool) AddTask(t task) {
func (p *TaskPool) AddTask(t task) {
id := atomic.AddUint32(&p.idx, 1) % defaultTaskQNumber
select {
......@@ -120,7 +95,7 @@ func (p *taskPool) AddTask(t task) {
}
// stop all tasks
func (p *taskPool) stop() {
func (p *TaskPool) stop() {
select {
case <-p.done:
return
......@@ -132,7 +107,7 @@ func (p *taskPool) stop() {
}
// check whether the session has been closed.
func (p *taskPool) isClosed() bool {
func (p *TaskPool) isClosed() bool {
select {
case <-p.done:
return true
......@@ -142,7 +117,7 @@ func (p *taskPool) isClosed() bool {
}
}
func (p *taskPool) close() {
func (p *TaskPool) close() {
p.stop()
p.wg.Wait()
for i := range p.qArray {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment