Commit 92f8562c authored by AlexStocks's avatar AlexStocks

fix wss client bug

parent 75b0c121
...@@ -11,6 +11,12 @@ ...@@ -11,6 +11,12 @@
## develop history ## ## develop history ##
--- ---
- 2017/04/21
> bug fix
* 1 client can not connect wss server because of getty does not verify whether cert&key is nil or not in client.go:dialWSS
> version: 0.7.02
- 2017/02/08 - 2017/02/08
> improvement > improvement
> >
......
...@@ -173,6 +173,8 @@ func (c *Client) dialWSS() Session { ...@@ -173,6 +173,8 @@ func (c *Client) dialWSS() Session {
var ( var (
err error err error
certPem []byte certPem []byte
root *x509.Certificate
roots []*x509.Certificate
certPool *x509.CertPool certPool *x509.CertPool
config *tls.Config config *tls.Config
dialer websocket.Dialer dialer websocket.Dialer
...@@ -186,26 +188,36 @@ func (c *Client) dialWSS() Session { ...@@ -186,26 +188,36 @@ func (c *Client) dialWSS() Session {
InsecureSkipVerify: true, InsecureSkipVerify: true,
} }
if c.cert != "" && c.privateKey != "" {
config.Certificates = make([]tls.Certificate, 1)
if config.Certificates[0], err = tls.LoadX509KeyPair(c.cert, c.privateKey); err != nil {
panic(fmt.Sprintf("tls.LoadX509KeyPair(cert{%s}, privateKey{%s}) = err{%#v}", c.cert, c.privateKey, err))
}
}
certPool = x509.NewCertPool()
for _, c := range config.Certificates {
roots, err = x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
if err != nil {
panic(fmt.Sprintf("error parsing server's root cert: %v\n", err))
}
for _, root = range roots {
certPool.AddCert(root)
}
}
gxlog.CInfo("client cert:%s, key:%s, caCert:%s", c.cert, c.privateKey, c.caCert) gxlog.CInfo("client cert:%s, key:%s, caCert:%s", c.cert, c.privateKey, c.caCert)
if c.caCert != "" { if c.caCert != "" {
certPem, err = ioutil.ReadFile(c.caCert) certPem, err = ioutil.ReadFile(c.caCert)
if err != nil { if err != nil {
panic(fmt.Errorf("ioutil.ReadFile(caCert{%s}) = err{%#v}", c.caCert, err)) panic(fmt.Errorf("ioutil.ReadFile(caCert{%s}) = err{%#v}", c.caCert, err))
} }
certPool = x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM(certPem); !ok { if ok := certPool.AppendCertsFromPEM(certPem); !ok {
panic("failed to parse root certificate file.") panic("failed to parse root certificate file.")
} }
config.RootCAs = certPool
config.InsecureSkipVerify = false config.InsecureSkipVerify = false
} }
config.RootCAs = certPool
if c.cert != "" && c.privateKey != "" {
config.Certificates = make([]tls.Certificate, 1)
if config.Certificates[0], err = tls.LoadX509KeyPair(c.cert, c.privateKey); err != nil {
panic(fmt.Sprintf("tls.LoadX509KeyPair(cert{%s}, privateKey{%s}) = err{%#v}", c.cert, c.privateKey, err))
}
}
// dialer.EnableCompression = true // dialer.EnableCompression = true
dialer.TLSClientConfig = config dialer.TLSClientConfig = config
...@@ -222,6 +234,7 @@ func (c *Client) dialWSS() Session { ...@@ -222,6 +234,7 @@ func (c *Client) dialWSS() Session {
if ss.(*session).maxMsgLen > 0 { if ss.(*session).maxMsgLen > 0 {
conn.SetReadLimit(int64(ss.(*session).maxMsgLen)) conn.SetReadLimit(int64(ss.(*session).maxMsgLen))
} }
ss.SetName(defaultWSSSessionName)
return ss return ss
} }
...@@ -233,10 +246,12 @@ func (c *Client) dialWSS() Session { ...@@ -233,10 +246,12 @@ func (c *Client) dialWSS() Session {
} }
func (c *Client) dial() Session { func (c *Client) dial() Session {
if strings.HasPrefix(c.addr, "wss") {
return c.dialWSS()
}
if strings.HasPrefix(c.addr, "ws") { if strings.HasPrefix(c.addr, "ws") {
return c.dialWS() return c.dialWS()
} else if strings.HasPrefix(c.addr, "wss") {
return c.dialWSS()
} }
return c.dialTCP() return c.dialTCP()
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
package getty package getty
import ( import (
// "context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors" "errors"
...@@ -23,7 +23,6 @@ import ( ...@@ -23,7 +23,6 @@ import (
) )
import ( import (
"github.com/AlexStocks/goext/log"
"github.com/AlexStocks/goext/net" "github.com/AlexStocks/goext/net"
"github.com/AlexStocks/goext/sync" "github.com/AlexStocks/goext/sync"
"github.com/AlexStocks/goext/time" "github.com/AlexStocks/goext/time"
...@@ -54,8 +53,8 @@ func NewServer() *Server { ...@@ -54,8 +53,8 @@ func NewServer() *Server {
func (s *Server) stop() { func (s *Server) stop() {
var ( var (
// err error err error
// ctx context.Context ctx context.Context
) )
select { select {
case <-s.done: case <-s.done:
...@@ -65,12 +64,12 @@ func (s *Server) stop() { ...@@ -65,12 +64,12 @@ func (s *Server) stop() {
close(s.done) close(s.done)
s.lock.Lock() s.lock.Lock()
if s.server != nil { if s.server != nil {
// ctx, _ = context.WithTimeout(context.Background(), serverFastFailTimeout) ctx, _ = context.WithTimeout(context.Background(), serverFastFailTimeout)
// if err = s.server.Shutdown(ctx); err != nil { if err = s.server.Shutdown(ctx); err != nil {
// // 如果下面内容输出为:server shutdown ctx: context deadline exceeded, // 如果下面内容输出为:server shutdown ctx: context deadline exceeded,
// // 则说明有未处理完的active connections。 // 则说明有未处理完的active connections。
// log.Error("server shutdown ctx:%#v", err) log.Error("server shutdown ctx:%#v", err)
// } }
} }
s.lock.Unlock() s.lock.Unlock()
// 把listener.Close放在这里,既能防止多次关闭调用, // 把listener.Close放在这里,既能防止多次关闭调用,
...@@ -252,30 +251,32 @@ func (s *Server) RunWSSEventLoop( ...@@ -252,30 +251,32 @@ func (s *Server) RunWSSEventLoop(
path string, path string,
cert string, cert string,
privateKey string, privateKey string,
caCert string) { caCert string,
) {
s.wg.Add(1) s.wg.Add(1)
go func() { go func() {
defer s.wg.Done() defer s.wg.Done()
var ( var (
err error err error
certPem []byte certPem []byte
certPool *x509.CertPool certificate tls.Certificate
config *tls.Config certPool *x509.CertPool
handler *wsHandler config *tls.Config
server *http.Server handler *wsHandler
server *http.Server
) )
config = &tls.Config{ if certificate, err = tls.LoadX509KeyPair(cert, privateKey); err != nil {
InsecureSkipVerify: true,
ClientAuth: tls.NoClientCert,
}
config.Certificates = make([]tls.Certificate, 1)
gxlog.CInfo("server cert:%s, key:%s, caCert:%s", cert, privateKey, caCert)
if config.Certificates[0], err = tls.LoadX509KeyPair(cert, privateKey); err != nil {
panic(fmt.Sprintf("tls.LoadX509KeyPair(cert{%s}, privateKey{%s}) = err{%#v}", cert, privateKey, err)) panic(fmt.Sprintf("tls.LoadX509KeyPair(cert{%s}, privateKey{%s}) = err{%#v}", cert, privateKey, err))
return return
} }
config = &tls.Config{
InsecureSkipVerify: true, // 不对对端的证书进行校验
ClientAuth: tls.NoClientCert,
NextProtos: []string{"http/1.1"},
Certificates: []tls.Certificate{certificate},
}
if caCert != "" { if caCert != "" {
certPem, err = ioutil.ReadFile(caCert) certPem, err = ioutil.ReadFile(caCert)
......
...@@ -28,12 +28,15 @@ import ( ...@@ -28,12 +28,15 @@ import (
) )
const ( const (
maxReadBufLen = 4 * 1024 maxReadBufLen = 4 * 1024
netIOTimeout = 1e9 // 1s netIOTimeout = 1e9 // 1s
period = 60 * 1e9 // 1 minute period = 60 * 1e9 // 1 minute
pendingDuration = 3e9 pendingDuration = 3e9
defaultSessionName = "session" defaultSessionName = "session"
outputFormat = "session %s, Read Count: %d, Write Count: %d, Read Pkg Count: %d, Write Pkg Count: %d" defaultTCPSessionName = "tcp-session"
defaultWSSessionName = "ws-session"
defaultWSSSessionName = "wss-session"
outputFormat = "session %s, Read Count: %d, Write Count: %d, Read Pkg Count: %d, Write Pkg Count: %d"
) )
///////////////////////////////////////// /////////////////////////////////////////
...@@ -121,7 +124,7 @@ func NewSession() Session { ...@@ -121,7 +124,7 @@ func NewSession() Session {
func NewTCPSession(conn net.Conn) Session { func NewTCPSession(conn net.Conn) Session {
session := &session{ session := &session{
name: defaultSessionName, name: defaultTCPSessionName,
Connection: newGettyTCPConn(conn), Connection: newGettyTCPConn(conn),
done: make(chan gxsync.Empty), done: make(chan gxsync.Empty),
period: period, period: period,
...@@ -137,7 +140,7 @@ func NewTCPSession(conn net.Conn) Session { ...@@ -137,7 +140,7 @@ func NewTCPSession(conn net.Conn) Session {
func NewWSSession(conn *websocket.Conn) Session { func NewWSSession(conn *websocket.Conn) Session {
session := &session{ session := &session{
name: defaultSessionName, name: defaultWSSessionName,
Connection: newGettyWSConn(conn), Connection: newGettyWSConn(conn),
done: make(chan gxsync.Empty), done: make(chan gxsync.Empty),
period: period, period: period,
......
...@@ -10,9 +10,9 @@ ...@@ -10,9 +10,9 @@
package getty package getty
const ( const (
Version = "0.7.01" Version = "0.7.02"
DATE = "2017/02/08" DATE = "2017/04/21"
GETTY_MAJOR = 0 GETTY_MAJOR = 0
GETTY_MINOR = 7 GETTY_MINOR = 7
GETTY_BUILD = 1 GETTY_BUILD = 2
) )
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment