Commit 97fc48e6 authored by AlexStocks's avatar AlexStocks

add wss

parent 3e3e7190
......@@ -21,6 +21,7 @@ import (
)
import (
"github.com/AlexStocks/goext/log"
"github.com/AlexStocks/goext/sync"
log "github.com/AlexStocks/log4go"
"github.com/gorilla/websocket"
......@@ -50,7 +51,9 @@ type Client struct {
wg sync.WaitGroup
// for wss client
certFile string
cert string // 客户端的证书
privateKey string // 客户端的私钥(包含了它的public key)
caCert string // 用于验证服务端的合法性
}
// NewClient function builds a tcp & ws client.
......@@ -78,8 +81,18 @@ func NewClient(connNum int, connInterval time.Duration, serverAddr string) *Clie
// @connNum is connection number.
// @connInterval is reconnect sleep interval when getty fails to connect the server.
// @serverAddr is server address.
// @ cert is certificate file
func NewWSSClient(connNum int, connInterval time.Duration, serverAddr string, cert string) *Client {
// @cert is client certificate file. it can be emtpy.
// @privateKey is client private key(contains its public key). it can be empty.
// @caCert is the root certificate file to verify the legitimacy of server
func NewWSSClient(
connNum int,
connInterval time.Duration,
serverAddr string,
cert string,
privateKey string,
caCert string,
) *Client {
if connNum < 0 {
connNum = 1
}
......@@ -93,7 +106,9 @@ func NewWSSClient(connNum int, connInterval time.Duration, serverAddr string, ce
addr: serverAddr,
ssMap: make(map[Session]gxsync.Empty, connNum),
done: make(chan gxsync.Empty),
certFile: cert,
caCert: caCert,
cert: cert,
privateKey: privateKey,
}
}
......@@ -135,6 +150,7 @@ func (c *Client) dialWS() Session {
return nil
}
conn, _, err = dialer.Dial(c.addr, nil)
log.Info("websocket.dialer.Dial(addr:%s) = error{%v}", c.addr, err)
if err == nil && conn.LocalAddr().String() == conn.RemoteAddr().String() {
err = errSelfConnect
}
......@@ -158,23 +174,41 @@ func (c *Client) dialWSS() Session {
err error
certPem []byte
certPool *x509.CertPool
config *tls.Config
dialer websocket.Dialer
conn *websocket.Conn
ss Session
)
dialer.EnableCompression = true
certPem, err = ioutil.ReadFile(c.certFile)
config = &tls.Config{
// InsecureSkipVerify: true,
}
gxlog.CInfo("client cert:%s, key:%s, caCert:%s", c.cert, c.privateKey, c.caCert)
if c.caCert != "" {
certPem, err = ioutil.ReadFile(c.caCert)
if err != nil {
panic(fmt.Errorf("ioutil.ReadFile(certFile{%s}) = err{%#v}", c.certFile, err))
panic(fmt.Errorf("ioutil.ReadFile(caCert{%s}) = err{%#v}", c.caCert, err))
}
certPool = x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM(certPem); !ok {
panic("failed to parse root certificate")
panic("failed to parse root certificate file.")
}
config.RootCAs = certPool
config.InsecureSkipVerify = false
}
if c.cert != "" && c.privateKey != "" {
config.Certificates = make([]tls.Certificate, 1)
if config.Certificates[0], err = tls.LoadX509KeyPair(c.cert, c.privateKey); err != nil {
panic(fmt.Sprintf("tls.LoadX509KeyPair(cert{%s}, privateKey{%s}) = err{%#v}", c.cert, c.privateKey, err))
}
}
// dialer.EnableCompression = true
dialer.TLSClientConfig = &tls.Config{RootCAs: certPool}
dialer.TLSClientConfig = config
for {
if c.IsClosed() {
return nil
......@@ -208,7 +242,7 @@ func (c *Client) dial() Session {
return c.dialTCP()
}
func (c *Client) ssNum() int {
func (c *Client) sessionNum() int {
var num int
c.Lock()
......@@ -270,7 +304,7 @@ func (c *Client) RunEventLoop(newSession NewSessionCallback) {
break
}
num = c.ssNum()
num = c.sessionNum()
// log.Info("current client connction number:%d", num)
if max <= num {
times++
......
......@@ -12,7 +12,10 @@ package getty
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"sync"
......@@ -20,6 +23,7 @@ import (
)
import (
"github.com/AlexStocks/goext/log"
"github.com/AlexStocks/goext/net"
"github.com/AlexStocks/goext/sync"
"github.com/AlexStocks/goext/time"
......@@ -29,7 +33,7 @@ import (
var (
errSelfConnect = errors.New("connect self!")
serverFastFailTimeout = gxtime.TimeSecondDuration(2)
serverFastFailTimeout = gxtime.TimeSecondDuration(1)
)
type Server struct {
......@@ -240,24 +244,50 @@ func (s *Server) RunWSEventLoop(newSession NewSessionCallback, path string) {
// RunWSSEventLoop serve websocket client request
// @newSession: new websocket connection callback
// @path: websocket request url path
func (s *Server) RunWSSEventLoop(newSession NewSessionCallback, path string, cert string, priv string) {
// @cert: server certificate file
// @privateKey: server private key(contains its public key)
// @caCert: root certificate file. to verify the legitimacy of client. it can be nil.
func (s *Server) RunWSSEventLoop(
newSession NewSessionCallback,
path string,
cert string,
privateKey string,
caCert string) {
s.wg.Add(1)
go func() {
defer s.wg.Done()
var (
err error
certPem []byte
certPool *x509.CertPool
config *tls.Config
handler *wsHandler
server *http.Server
)
config = &tls.Config{}
config = &tls.Config{InsecureSkipVerify: true}
config.Certificates = make([]tls.Certificate, 1)
if config.Certificates[0], err = tls.LoadX509KeyPair(cert, priv); err != nil {
log.Error("tls.LoadX509KeyPair(cert{%s}, priv{%s}) = err{%#v}", cert, priv, err)
gxlog.CInfo("server cert:%s, key:%s, caCert:%s", cert, privateKey, caCert)
if config.Certificates[0], err = tls.LoadX509KeyPair(cert, privateKey); err != nil {
log.Error("tls.LoadX509KeyPair(cert{%s}, privateKey{%s}) = err{%#v}", cert, privateKey, err)
return
}
if caCert != "" {
certPem, err = ioutil.ReadFile(caCert)
if err != nil {
panic(fmt.Errorf("ioutil.ReadFile(certFile{%s}) = err{%#v}", caCert, err))
}
certPool = x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM(certPem); !ok {
panic("failed to parse root certificate file")
}
config.ClientCAs = certPool
config.ClientAuth = tls.RequireAndVerifyClientCert
config.InsecureSkipVerify = false
}
handler = newWSHandler(s, newSession)
handler.HandleFunc(path, handler.serveWSRequest)
server = &http.Server{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment