Unverified Commit 8ef4d628 authored by XavierNiu, committed by GitHub

Ftr: ConnectionPool (#66)

* ftr: work pool

* rename WorkerPool -> FixedWorkerPool

* go fmt

* fix typo

* benchmark for ConnectionPool

* go fmt

* benchmark for ConnectionPool

* chan buffer extension for ConnectionPool benchmark

* reset benchmark tests for TaskPool

* reduce buffer of ConnectionPool

* unbounded chan benchmark

* reduce goroutine num

* ConnectionPool adopts task queue array

* apache license

* fix shadow variable

* remove done channel, add numWorkers

* go fmt

* fix wrong logger

* add newBaseWorkerPool

* remove ConnectionPoolConfig

* update comments

* update comments

* fix unittests

* Update base_worker_pool.go

* Update base_worker_pool.go

* fix newBaseWorkerPool bugs

* go fmt

* fix CountTaskSync bugs

* fix CountTask bug

* optimize code

* go fmt

* add SubmitSync test

* go fmt
Co-authored-by: Xin.Zh <dragoncharlie@foxmail.com>
parent 22480179
@@ -115,9 +115,6 @@ func testQuota1(t *testing.T) {
ch.In() <- i
}
assert.True(t, 15 >= ch.Cap())
assert.True(t, 15 >= ch.Len())
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
@@ -125,18 +122,12 @@ func testQuota1(t *testing.T) {
ch.In() <- 15
}()
assert.True(t, 15 >= ch.Cap())
assert.True(t, 15 >= ch.Len())
for i := 0; i < 16; i++ {
v, ok := <-ch.Out()
assert.True(t, ok)
count += v.(int)
}
assert.True(t, 15 >= ch.Len())
assert.True(t, 10 >= ch.Cap())
wg.Wait()
assert.Equal(t, 165, count)
@@ -225,3 +216,51 @@ func testQuota2(t *testing.T) {
default:
}
}
func BenchmarkUnboundedChan_Fixed(b *testing.B) {
ch := NewUnboundedChanWithQuota(1000, 1000)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
select {
case ch.In() <- 1:
}
<-ch.Out()
}
})
close(ch.In())
}
func BenchmarkUnboundedChan_Extension(b *testing.B) {
ch := NewUnboundedChanWithQuota(1000, 100000)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
select {
case ch.In() <- 1:
}
<-ch.Out()
}
})
close(ch.In())
}
func BenchmarkUnboundedChan_ExtensionUnlimited(b *testing.B) {
ch := NewUnboundedChanWithQuota(1000, 0)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
select {
case ch.In() <- 1:
}
<-ch.Out()
}
})
close(ch.In())
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gxlog
type Logger interface {
Info(args ...interface{})
Warn(args ...interface{})
Error(args ...interface{})
Debug(args ...interface{})
Infof(fmt string, args ...interface{})
Warnf(fmt string, args ...interface{})
Errorf(fmt string, args ...interface{})
Debugf(fmt string, args ...interface{})
}
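Any type that implements these eight methods can be supplied to the pool as its logger. Below is a minimal illustrative sketch, not part of this commit: stdLogger is a hypothetical adapter backed by the standard library log package.
package main
import (
    "log"
)
import (
    gxlog "github.com/dubbogo/gost/log"
)
// stdLogger is a hypothetical adapter that forwards every call to the
// standard library logger; it exists only to illustrate the interface.
type stdLogger struct{}
// compile-time check that stdLogger satisfies gxlog.Logger
var _ gxlog.Logger = stdLogger{}
func (stdLogger) Info(args ...interface{})  { log.Println(args...) }
func (stdLogger) Warn(args ...interface{})  { log.Println(args...) }
func (stdLogger) Error(args ...interface{}) { log.Println(args...) }
func (stdLogger) Debug(args ...interface{}) { log.Println(args...) }
func (stdLogger) Infof(format string, args ...interface{})  { log.Printf(format, args...) }
func (stdLogger) Warnf(format string, args ...interface{})  { log.Printf(format, args...) }
func (stdLogger) Errorf(format string, args ...interface{}) { log.Printf(format, args...) }
func (stdLogger) Debugf(format string, args ...interface{}) { log.Printf(format, args...) }
func main() {
    stdLogger{}.Infof("worker pool logger ready, %d queues", 2)
}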
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gxsync
import (
"fmt"
"runtime/debug"
"sync"
)
import (
"go.uber.org/atomic"
)
import (
gxlog "github.com/dubbogo/gost/log"
)
type WorkerPoolConfig struct {
NumWorkers int // the number of worker goroutines
NumQueues int // the number of task queues
QueueSize int // the buffer size of each task queue
Logger gxlog.Logger // optional logger, may be nil
}
// baseWorkerPool is a worker pool with multiple task queues.
//
// The picture below shows the baseWorkerPool architecture.
// Note that:
// - TaskQueueX is a buffered channel, please refer to taskQueues.
// - Each worker consumes tasks only from the queue it is bound to, please refer to dispatch(numWorkers).
// - taskId is incremented by 1 each time a task is enqueued.
// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────────────────┐
// │worker0│ │worker2│ │worker4│ ┌─┤ taskId % NumQueues == 0 │
// └───────┘ └───────┘ └───────┘ │ └─────────────────────────┘
// │ │ │ │
// └───────consume───────┘ enqueue
// ▼ task ╔══════════════════╗
// ┌──┬──┬──┬──┬──┬──┬──┬──┬──┬──┐ │ ║ baseWorkerPool: ║
// TaskQueue0 │t0│t1│t2│t3│t4│t5│t6│t7│t8│t9│◀─┘ ║ ║
// ├──┼──┼──┼──┼──┼──┼──┼──┼──┼──┤ ║ *NumWorkers=6 ║
// TaskQueue1 │t0│t1│t2│t3│t4│t5│t6│t7│t8│t9│◀┐ ║ *NumQueues=2 ║
// └──┴──┴──┴──┴──┴──┴──┴──┴──┴──┘ │ ║ *QueueSize=10 ║
// ▲ enqueue ╚══════════════════╝
// ┌───────consume───────┐ task
// │ │ │ │
// ┌───────┐ ┌───────┐ ┌───────┐ │ ┌─────────────────────────┐
// │worker1│ │worker3│ │worker5│ └──│ taskId % NumQueues == 1 │
// └───────┘ └───────┘ └───────┘ └─────────────────────────┘
type baseWorkerPool struct {
logger gxlog.Logger
taskId uint32
taskQueues []chan task
numWorkers *atomic.Int32
wg *sync.WaitGroup
}
func newBaseWorkerPool(config WorkerPoolConfig) *baseWorkerPool {
if config.NumWorkers < 1 {
config.NumWorkers = 1
}
if config.NumQueues < 1 {
config.NumQueues = 1
}
if config.QueueSize < 0 {
config.QueueSize = 0
}
taskQueues := make([]chan task, config.NumQueues)
for i := range taskQueues {
taskQueues[i] = make(chan task, config.QueueSize)
}
p := &baseWorkerPool{
logger: config.Logger,
taskQueues: taskQueues,
numWorkers: new(atomic.Int32),
wg: new(sync.WaitGroup),
}
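// block until every worker goroutine has started, so the pool is fully ready when newBaseWorkerPool returns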
initWg := new(sync.WaitGroup)
initWg.Add(config.NumWorkers)
p.dispatch(config.NumWorkers, initWg)
initWg.Wait()
return p
}
func (p *baseWorkerPool) dispatch(numWorkers int, wg *sync.WaitGroup) {
for i := 0; i < numWorkers; i++ {
p.newWorker(i, wg)
}
}
func (p *baseWorkerPool) Submit(t task) error {
panic("implement me")
}
func (p *baseWorkerPool) SubmitSync(t task) error {
panic("implement me")
}
func (p *baseWorkerPool) Close() {
if p.IsClosed() {
return
}
for _, q := range p.taskQueues {
close(q)
}
p.wg.Wait()
}
func (p *baseWorkerPool) IsClosed() bool {
return p.NumWorkers() == 0
}
func (p *baseWorkerPool) NumWorkers() int32 {
return p.numWorkers.Load()
}
func (p *baseWorkerPool) newWorker(workerId int, wg *sync.WaitGroup) {
p.wg.Add(1)
p.numWorkers.Add(1)
go p.worker(workerId, wg)
}
func (p *baseWorkerPool) worker(workerId int, wg *sync.WaitGroup) {
defer func() {
if n := p.numWorkers.Add(-1); n < 0 {
panic(fmt.Sprintf("numWorkers should be greater than or equal to 0, but the value is %d", n))
}
p.wg.Done()
}()
if p.logger != nil {
p.logger.Infof("worker #%d is started\n", workerId)
}
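// bind this worker to a single task queue; workers whose ids share the same remainder serve the same queue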
chanId := workerId % len(p.taskQueues)
wg.Done()
for {
select {
case t, ok := <-p.taskQueues[chanId]:
if !ok {
if p.logger != nil {
p.logger.Infof("worker #%d is closed\n", workerId)
}
return
}
if t != nil {
func() {
// recover from a task panic so one bad task cannot kill the worker
defer func() {
if r := recover(); r != nil {
if p.logger != nil {
p.logger.Errorf("goroutine panic: %v\n%s\n", r, string(debug.Stack()))
}
}
}()
// execute task
t()
}()
}
}
}
}
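The modulo routing shown in the diagram above can be traced with a small standalone sketch. It is illustrative only: it mirrors the worker binding (workerId % number of queues) and the task routing (taskId % number of queues) rather than exercising the pool itself.
package main
import "fmt"
func main() {
    // The values from the diagram in base_worker_pool.go.
    const numWorkers, numQueues = 6, 2
    // Each worker is pinned to one queue, as in worker().
    for workerId := 0; workerId < numWorkers; workerId++ {
        fmt.Printf("worker%d consumes TaskQueue%d\n", workerId, workerId%numQueues)
    }
    // Each task is routed by its incrementing id, as in the diagram's taskId % NumQueues rule.
    for taskId := 0; taskId < 4; taskId++ {
        fmt.Printf("task%d is enqueued to TaskQueue%d\n", taskId, taskId%numQueues)
    }
}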
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gxsync
import (
"math/rand"
"sync/atomic"
)
import (
perrors "github.com/pkg/errors"
)
var (
PoolBusyErr = perrors.New("pool is busy")
)
func NewConnectionPool(config WorkerPoolConfig) WorkerPool {
return &ConnectionPool{
baseWorkerPool: newBaseWorkerPool(config),
}
}
type ConnectionPool struct {
*baseWorkerPool
}
func (p *ConnectionPool) Submit(t task) error {
if t == nil {
return perrors.New("task shouldn't be nil")
}
// enqueue the task into a queue chosen by round-robin
taskId := atomic.AddUint32(&p.taskId, 1)
select {
case p.taskQueues[int(taskId)%len(p.taskQueues)] <- t:
return nil
default:
}
// the preferred queue is full: fall back to a random queue, with at most len(p.taskQueues)/2 attempts
for i := 0; i < len(p.taskQueues)/2; i++ {
select {
case p.taskQueues[rand.Intn(len(p.taskQueues))] <- t:
return nil
default:
continue
}
}
return PoolBusyErr
}
func (p *ConnectionPool) SubmitSync(t task) error {
done := make(chan struct{})
fn := func() {
defer close(done)
t()
}
if err := p.Submit(fn); err != nil {
return err
}
<-done
return nil
}
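For context, a minimal sketch of driving the new pool from caller code. It assumes the import path github.com/dubbogo/gost/sync, and the handling of PoolBusyErr below (just logging it) is only one possible choice, not something this commit prescribes.
package main
import (
    "fmt"
    "runtime"
)
import (
    gxsync "github.com/dubbogo/gost/sync"
)
func main() {
    p := gxsync.NewConnectionPool(gxsync.WorkerPoolConfig{
        NumWorkers: 100,
        NumQueues:  runtime.NumCPU(),
        QueueSize:  10,
        Logger:     nil,
    })
    defer p.Close()
    // Submit returns as soon as the task has been queued.
    if err := p.Submit(func() { fmt.Println("async task executed") }); err != nil {
        // Every queue was full: err is PoolBusyErr and the task was dropped.
        fmt.Println("submit failed:", err)
    }
    // SubmitSync blocks until the task has been executed.
    _ = p.SubmitSync(func() { fmt.Println("sync task executed") })
}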
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gxsync
import (
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
)
import (
"github.com/stretchr/testify/assert"
)
func TestConnectionPool(t *testing.T) {
t.Run("Count", func(t *testing.T) {
p := NewConnectionPool(WorkerPoolConfig{
NumWorkers: 100,
NumQueues: runtime.NumCPU(),
QueueSize: 10,
Logger: nil,
})
var count int64
wg := new(sync.WaitGroup)
for i := 1; i <= 100; i++ {
wg.Add(1)
value := i
err := p.Submit(func() {
defer wg.Done()
atomic.AddInt64(&count, int64(value))
})
assert.Nil(t, err)
}
wg.Wait()
assert.Equal(t, int64(5050), count)
p.Close()
})
t.Run("PoolBusyErr", func(t *testing.T) {
p := NewConnectionPool(WorkerPoolConfig{
NumWorkers: 1,
NumQueues: 1,
QueueSize: 0,
Logger: nil,
})
wg := new(sync.WaitGroup)
wg.Add(1)
err := p.Submit(func() {
wg.Wait()
})
assert.Nil(t, err)
err = p.Submit(func() {})
assert.Equal(t, PoolBusyErr, err)
wg.Done()
time.Sleep(100 * time.Millisecond)
err = p.Submit(func() {})
assert.Nil(t, err)
p.Close()
})
t.Run("Close", func(t *testing.T) {
p := NewConnectionPool(WorkerPoolConfig{
NumWorkers: runtime.NumCPU(),
NumQueues: runtime.NumCPU(),
QueueSize: 100,
Logger: nil,
})
assert.Equal(t, runtime.NumCPU(), int(p.NumWorkers()))
p.Close()
assert.True(t, p.IsClosed())
assert.Panics(t, func() {
_ = p.Submit(func() {})
})
})
t.Run("BorderCondition", func(t *testing.T) {
p := NewConnectionPool(WorkerPoolConfig{
NumWorkers: 0,
NumQueues: runtime.NumCPU(),
QueueSize: 100,
Logger: nil,
})
assert.Equal(t, 1, int(p.NumWorkers()))
p.Close()
p = NewConnectionPool(WorkerPoolConfig{
NumWorkers: 1,
NumQueues: 0,
QueueSize: 0,
Logger: nil,
})
err := p.Submit(func() {})
assert.Nil(t, err)
p.Close()
p = NewConnectionPool(WorkerPoolConfig{
NumWorkers: 1,
NumQueues: 1,
QueueSize: -1,
Logger: nil,
})
err = p.Submit(func() {})
assert.Nil(t, err)
p.Close()
})
t.Run("NilTask", func(t *testing.T) {
p := NewConnectionPool(WorkerPoolConfig{
NumWorkers: 1,
NumQueues: 1,
QueueSize: 0,
Logger: nil,
})
err := p.Submit(nil)
assert.NotNil(t, err)
p.Close()
})
t.Run("CountTask", func(t *testing.T) {
p := NewConnectionPool(WorkerPoolConfig{
NumWorkers: runtime.NumCPU(),
NumQueues: runtime.NumCPU(),
QueueSize: 10,
Logger: nil,
})
task, v := newCountTask()
wg := new(sync.WaitGroup)
wg.Add(100)
for i := 0; i < 100; i++ {
if err := p.Submit(func() {
defer wg.Done()
task()
}); err != nil {
i--
}
}
wg.Wait()
assert.Equal(t, 100, int(*v))
p.Close()
})
t.Run("CountTaskSync", func(t *testing.T) {
p := NewConnectionPool(WorkerPoolConfig{
NumWorkers: runtime.NumCPU(),
NumQueues: runtime.NumCPU(),
QueueSize: 10,
Logger: nil,
})
task, v := newCountTask()
for i := 0; i < 100; i++ {
err := p.SubmitSync(task)
assert.Nil(t, err)
}
assert.Equal(t, 100, int(*v))
p.Close()
})
}
func BenchmarkConnectionPool(b *testing.B) {
p := NewConnectionPool(WorkerPoolConfig{
NumWorkers: 100,
NumQueues: runtime.NumCPU(),
QueueSize: 100,
Logger: nil,
})
b.Run("CountTask", func(b *testing.B) {
task, _ := newCountTask()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = p.Submit(task)
}
})
})
b.Run("CPUTask", func(b *testing.B) {
task, _ := newCPUTask()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = p.Submit(task)
}
})
})
b.Run("IOTask", func(b *testing.B) {
task, _ := newIOTask()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = p.Submit(task)
}
})
})
b.Run("RandomTask", func(b *testing.B) {
task, _ := newRandomTask()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = p.Submit(task)
}
})
})
}
@@ -230,7 +230,7 @@ func TestTaskPool(t *testing.T) {
func BenchmarkTaskPool_CountTask(b *testing.B) {
tp := NewTaskPool(
WithTaskPoolTaskPoolSize(runtime.NumCPU()),
WithTaskPoolTaskPoolSize(100),
WithTaskPoolTaskQueueNumber(runtime.NumCPU()),
// WithTaskPoolTaskQueueLength(runtime.NumCPU()),
)
@@ -266,7 +266,7 @@ func BenchmarkTaskPool_CountTask(b *testing.B) {
// cpu-intensive task
func BenchmarkTaskPool_CPUTask(b *testing.B) {
tp := NewTaskPool(
WithTaskPoolTaskPoolSize(runtime.NumCPU()),
WithTaskPoolTaskPoolSize(100),
WithTaskPoolTaskQueueNumber(runtime.NumCPU()),
// WithTaskPoolTaskQueueLength(runtime.NumCPU()),
)
@@ -311,7 +311,7 @@ func BenchmarkTaskPool_CPUTask(b *testing.B) {
// IO-intensive task
func BenchmarkTaskPool_IOTask(b *testing.B) {
tp := NewTaskPool(
WithTaskPoolTaskPoolSize(runtime.NumCPU()),
WithTaskPoolTaskPoolSize(100),
WithTaskPoolTaskQueueNumber(runtime.NumCPU()),
// WithTaskPoolTaskQueueLength(runtime.NumCPU()),
)
@@ -346,7 +346,7 @@ func BenchmarkTaskPool_IOTask(b *testing.B) {
func BenchmarkTaskPool_RandomTask(b *testing.B) {
tp := NewTaskPool(
WithTaskPoolTaskPoolSize(runtime.NumCPU()),
WithTaskPoolTaskPoolSize(100),
WithTaskPoolTaskQueueNumber(runtime.NumCPU()),
// WithTaskPoolTaskQueueLength(runtime.NumCPU()),
)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gxsync
type WorkerPool interface {
// Submit adds a task to the queue asynchronously.
Submit(task) error
// SubmitSync adds a task to the queue and blocks until the task has been executed.
SubmitSync(task) error
// Close closes the worker pool.
Close()
// IsClosed returns the close status of the worker pool.
IsClosed() bool
// NumWorkers returns the number of running workers.
NumWorkers() int32
}