Commit 97fc48e6 authored by AlexStocks's avatar AlexStocks

add wss

parent 3e3e7190
...@@ -21,6 +21,7 @@ import ( ...@@ -21,6 +21,7 @@ import (
) )
import ( import (
"github.com/AlexStocks/goext/log"
"github.com/AlexStocks/goext/sync" "github.com/AlexStocks/goext/sync"
log "github.com/AlexStocks/log4go" log "github.com/AlexStocks/log4go"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
...@@ -50,7 +51,9 @@ type Client struct { ...@@ -50,7 +51,9 @@ type Client struct {
wg sync.WaitGroup wg sync.WaitGroup
// for wss client // for wss client
certFile string cert string // 客户端的证书
privateKey string // 客户端的私钥(包含了它的public key)
caCert string // 用于验证服务端的合法性
} }
// NewClient function builds a tcp & ws client. // NewClient function builds a tcp & ws client.
...@@ -78,8 +81,18 @@ func NewClient(connNum int, connInterval time.Duration, serverAddr string) *Clie ...@@ -78,8 +81,18 @@ func NewClient(connNum int, connInterval time.Duration, serverAddr string) *Clie
// @connNum is connection number. // @connNum is connection number.
// @connInterval is reconnect sleep interval when getty fails to connect the server. // @connInterval is reconnect sleep interval when getty fails to connect the server.
// @serverAddr is server address. // @serverAddr is server address.
// @ cert is certificate file // @cert is client certificate file. it can be emtpy.
func NewWSSClient(connNum int, connInterval time.Duration, serverAddr string, cert string) *Client { // @privateKey is client private key(contains its public key). it can be empty.
// @caCert is the root certificate file to verify the legitimacy of server
func NewWSSClient(
connNum int,
connInterval time.Duration,
serverAddr string,
cert string,
privateKey string,
caCert string,
) *Client {
if connNum < 0 { if connNum < 0 {
connNum = 1 connNum = 1
} }
...@@ -88,12 +101,14 @@ func NewWSSClient(connNum int, connInterval time.Duration, serverAddr string, ce ...@@ -88,12 +101,14 @@ func NewWSSClient(connNum int, connInterval time.Duration, serverAddr string, ce
} }
return &Client{ return &Client{
number: connNum, number: connNum,
interval: connInterval, interval: connInterval,
addr: serverAddr, addr: serverAddr,
ssMap: make(map[Session]gxsync.Empty, connNum), ssMap: make(map[Session]gxsync.Empty, connNum),
done: make(chan gxsync.Empty), done: make(chan gxsync.Empty),
certFile: cert, caCert: caCert,
cert: cert,
privateKey: privateKey,
} }
} }
...@@ -135,6 +150,7 @@ func (c *Client) dialWS() Session { ...@@ -135,6 +150,7 @@ func (c *Client) dialWS() Session {
return nil return nil
} }
conn, _, err = dialer.Dial(c.addr, nil) conn, _, err = dialer.Dial(c.addr, nil)
log.Info("websocket.dialer.Dial(addr:%s) = error{%v}", c.addr, err)
if err == nil && conn.LocalAddr().String() == conn.RemoteAddr().String() { if err == nil && conn.LocalAddr().String() == conn.RemoteAddr().String() {
err = errSelfConnect err = errSelfConnect
} }
...@@ -158,23 +174,41 @@ func (c *Client) dialWSS() Session { ...@@ -158,23 +174,41 @@ func (c *Client) dialWSS() Session {
err error err error
certPem []byte certPem []byte
certPool *x509.CertPool certPool *x509.CertPool
config *tls.Config
dialer websocket.Dialer dialer websocket.Dialer
conn *websocket.Conn conn *websocket.Conn
ss Session ss Session
) )
dialer.EnableCompression = true dialer.EnableCompression = true
certPem, err = ioutil.ReadFile(c.certFile)
if err != nil { config = &tls.Config{
panic(fmt.Errorf("ioutil.ReadFile(certFile{%s}) = err{%#v}", c.certFile, err)) // InsecureSkipVerify: true,
} }
certPool = x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM(certPem); !ok { gxlog.CInfo("client cert:%s, key:%s, caCert:%s", c.cert, c.privateKey, c.caCert)
panic("failed to parse root certificate") if c.caCert != "" {
certPem, err = ioutil.ReadFile(c.caCert)
if err != nil {
panic(fmt.Errorf("ioutil.ReadFile(caCert{%s}) = err{%#v}", c.caCert, err))
}
certPool = x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM(certPem); !ok {
panic("failed to parse root certificate file.")
}
config.RootCAs = certPool
config.InsecureSkipVerify = false
}
if c.cert != "" && c.privateKey != "" {
config.Certificates = make([]tls.Certificate, 1)
if config.Certificates[0], err = tls.LoadX509KeyPair(c.cert, c.privateKey); err != nil {
panic(fmt.Sprintf("tls.LoadX509KeyPair(cert{%s}, privateKey{%s}) = err{%#v}", c.cert, c.privateKey, err))
}
} }
// dialer.EnableCompression = true // dialer.EnableCompression = true
dialer.TLSClientConfig = &tls.Config{RootCAs: certPool} dialer.TLSClientConfig = config
for { for {
if c.IsClosed() { if c.IsClosed() {
return nil return nil
...@@ -208,7 +242,7 @@ func (c *Client) dial() Session { ...@@ -208,7 +242,7 @@ func (c *Client) dial() Session {
return c.dialTCP() return c.dialTCP()
} }
func (c *Client) ssNum() int { func (c *Client) sessionNum() int {
var num int var num int
c.Lock() c.Lock()
...@@ -270,7 +304,7 @@ func (c *Client) RunEventLoop(newSession NewSessionCallback) { ...@@ -270,7 +304,7 @@ func (c *Client) RunEventLoop(newSession NewSessionCallback) {
break break
} }
num = c.ssNum() num = c.sessionNum()
// log.Info("current client connction number:%d", num) // log.Info("current client connction number:%d", num)
if max <= num { if max <= num {
times++ times++
......
...@@ -12,7 +12,10 @@ package getty ...@@ -12,7 +12,10 @@ package getty
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509"
"errors" "errors"
"fmt"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"sync" "sync"
...@@ -20,6 +23,7 @@ import ( ...@@ -20,6 +23,7 @@ import (
) )
import ( import (
"github.com/AlexStocks/goext/log"
"github.com/AlexStocks/goext/net" "github.com/AlexStocks/goext/net"
"github.com/AlexStocks/goext/sync" "github.com/AlexStocks/goext/sync"
"github.com/AlexStocks/goext/time" "github.com/AlexStocks/goext/time"
...@@ -29,7 +33,7 @@ import ( ...@@ -29,7 +33,7 @@ import (
var ( var (
errSelfConnect = errors.New("connect self!") errSelfConnect = errors.New("connect self!")
serverFastFailTimeout = gxtime.TimeSecondDuration(2) serverFastFailTimeout = gxtime.TimeSecondDuration(1)
) )
type Server struct { type Server struct {
...@@ -236,28 +240,54 @@ func (s *Server) RunWSEventLoop(newSession NewSessionCallback, path string) { ...@@ -236,28 +240,54 @@ func (s *Server) RunWSEventLoop(newSession NewSessionCallback, path string) {
}() }()
} }
// serve websocket client request // serve websocket client request
// RunWSSEventLoop serve websocket client request // RunWSSEventLoop serve websocket client request
// @newSession: new websocket connection callback // @newSession: new websocket connection callback
// @path: websocket request url path // @path: websocket request url path
func (s *Server) RunWSSEventLoop(newSession NewSessionCallback, path string, cert string, priv string) { // @cert: server certificate file
// @privateKey: server private key(contains its public key)
// @caCert: root certificate file. to verify the legitimacy of client. it can be nil.
func (s *Server) RunWSSEventLoop(
newSession NewSessionCallback,
path string,
cert string,
privateKey string,
caCert string) {
s.wg.Add(1) s.wg.Add(1)
go func() { go func() {
defer s.wg.Done() defer s.wg.Done()
var ( var (
err error err error
config *tls.Config certPem []byte
handler *wsHandler certPool *x509.CertPool
server *http.Server config *tls.Config
handler *wsHandler
server *http.Server
) )
config = &tls.Config{} config = &tls.Config{InsecureSkipVerify: true}
config.Certificates = make([]tls.Certificate, 1) config.Certificates = make([]tls.Certificate, 1)
if config.Certificates[0], err = tls.LoadX509KeyPair(cert, priv); err != nil { gxlog.CInfo("server cert:%s, key:%s, caCert:%s", cert, privateKey, caCert)
log.Error("tls.LoadX509KeyPair(cert{%s}, priv{%s}) = err{%#v}", cert, priv, err) if config.Certificates[0], err = tls.LoadX509KeyPair(cert, privateKey); err != nil {
log.Error("tls.LoadX509KeyPair(cert{%s}, privateKey{%s}) = err{%#v}", cert, privateKey, err)
return return
} }
if caCert != "" {
certPem, err = ioutil.ReadFile(caCert)
if err != nil {
panic(fmt.Errorf("ioutil.ReadFile(certFile{%s}) = err{%#v}", caCert, err))
}
certPool = x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM(certPem); !ok {
panic("failed to parse root certificate file")
}
config.ClientCAs = certPool
config.ClientAuth = tls.RequireAndVerifyClientCert
config.InsecureSkipVerify = false
}
handler = newWSHandler(s, newSession) handler = newWSHandler(s, newSession)
handler.HandleFunc(path, handler.serveWSRequest) handler.HandleFunc(path, handler.serveWSRequest)
server = &http.Server{ server = &http.Server{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment