Commit 7c2d384d authored by alexstocks's avatar alexstocks

add NewWSSClient & RunWSEventLoopWithTLs

parent d5967faf
......@@ -10,6 +10,9 @@
package getty
import (
"crypto/tls"
"crypto/x509"
"io/ioutil"
"net"
"strings"
"sync"
......@@ -17,6 +20,7 @@ import (
)
import (
"fmt"
log "github.com/AlexStocks/log4go"
"github.com/gorilla/websocket"
)
......@@ -43,9 +47,12 @@ type Client struct {
sync.Once
done chan empty
wg sync.WaitGroup
// for wss client
certFile string
}
// NewClient function builds a client.
// NewClient function builds a tcp & ws client.
// @connNum is connection number.
// @connInterval is reconnect sleep interval when getty fails to connect the server.
// @serverAddr is server address.
......@@ -66,6 +73,29 @@ func NewClient(connNum int, connInterval time.Duration, serverAddr string) *Clie
}
}
// NewClient function builds a wss client.
// @connNum is connection number.
// @connInterval is reconnect sleep interval when getty fails to connect the server.
// @serverAddr is server address.
// @ cert is certificate file
func NewWSSClient(connNum int, connInterval time.Duration, serverAddr string, cert string) *Client {
if connNum < 0 {
connNum = 1
}
if connInterval < defaultInterval {
connInterval = defaultInterval
}
return &Client{
number: connNum,
interval: connInterval,
addr: serverAddr,
sessionMap: make(map[*Session]empty, connNum),
done: make(chan empty),
certFile: cert,
}
}
func (this *Client) dialTCP() *Session {
var (
err error
......@@ -115,9 +145,48 @@ func (this *Client) dialWS() *Session {
}
}
func (this *Client) dialWSS() *Session {
var (
err error
certPem []byte
certPool *x509.CertPool
conn *websocket.Conn
dialer websocket.Dialer
)
certPem, err = ioutil.ReadFile(this.certFile)
if err != nil {
panic(fmt.Errorf("ioutil.ReadFile(certFile{%s}) = err{%#v}", this.certFile, err))
}
certPool = x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM(certPem); !ok {
panic("failed to parse root certificate")
}
dialer.TLSClientConfig = &tls.Config{RootCAs: certPool}
for {
if this.IsClosed() {
return nil
}
conn, _, err = dialer.Dial(this.addr, nil)
if err == nil && conn.LocalAddr().String() == conn.RemoteAddr().String() {
err = errSelfConnect
}
if err == nil {
return NewWSSession(conn)
}
log.Info("websocket.dialer.Dial(addr:%s) = error{%v}", this.addr, err)
time.Sleep(this.interval)
continue
}
}
func (this *Client) dial() *Session {
if strings.HasPrefix(this.addr, "ws") {
this.dialWS()
return this.dialWS()
} else if strings.HasPrefix(this.addr, "wss") {
return this.dialWSS()
}
return this.dialTCP()
......
......@@ -11,6 +11,11 @@
## develop history ##
---
- 2016/10/09
> 1 add client.go:NewWSSClient
>
> 2 add server.go:RunWSEventLoopWithTLS
- 2016/10/08
> 1 add websocket connection & client & server
>
......
......@@ -10,6 +10,7 @@
package getty
import (
"crypto/tls"
"errors"
"net"
"net/http"
......@@ -127,23 +128,27 @@ func (this *Server) RunEventloop(newSession NewSessionCallback) {
}
type wsHandler struct {
http.ServeMux
server *Server
newSession NewSessionCallback
upgrader websocket.Upgrader
}
func newWSHandler(server *Server) wsHandler {
return wsHandler{
server: server,
func newWSHandler(server *Server, newSession NewSessionCallback) *wsHandler {
return &wsHandler{
server: server,
newSession: newSession,
upgrader: websocket.Upgrader{
// in default, ReadBufferSize & WriteBufferSize is 4k
// HandshakeTimeout: server.HTTPTimeout,
CheckOrigin: func(_ *http.Request) bool { return true },
},
}
}
func (this wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (this *wsHandler) ServeWSRequest(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
// w.WriteHeader(http.StatusMethodNotAllowed)
http.Error(w, "Method not allowed", 405)
return
}
......@@ -174,16 +179,67 @@ func (this wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
session.RunEventLoop()
}
func (this *Server) RunWSEventLoop(newSession NewSessionCallback) {
// RunWSEventLoop serve websocket client request
// @newSession: new websocket connection callback
// @path: websocket request url path
func (this *Server) RunWSEventLoop(newSession NewSessionCallback, path string) {
this.wg.Add(1)
go func() {
defer this.wg.Done()
(&http.Server{
var (
err error
handler *wsHandler
)
handler = newWSHandler(this, newSession)
handler.HandleFunc(path, handler.ServeWSRequest)
err = (&http.Server{
Addr: this.addr,
Handler: newWSHandler(this),
Handler: handler,
// ReadTimeout: server.HTTPTimeout,
// WriteTimeout: server.HTTPTimeout,
}).Serve(this.listener)
if err != nil {
log.Error("http.Server.Serve(addr{%s}) = err{%#v}", this.addr, err)
panic(err)
}
}()
}
// RunWSEventLoopWithTLS serve websocket client request
// @newSession: new websocket connection callback
// @path: websocket request url path
func (this *Server) RunWSEventLoopWithTLS(newSession NewSessionCallback, path string, cert string, priv string) {
this.wg.Add(1)
go func() {
defer this.wg.Done()
var (
err error
config *tls.Config
handler *wsHandler
server *http.Server
)
config = &tls.Config{}
config.Certificates = make([]tls.Certificate, 1)
if config.Certificates[0], err = tls.LoadX509KeyPair(cert, priv); err != nil {
log.Error("tls.LoadX509KeyPair(cert{%s}, priv{%s}) = err{%#v}", cert, priv, err)
return
}
handler = newWSHandler(this, newSession)
handler.HandleFunc(path, handler.ServeWSRequest)
server = &http.Server{
Addr: this.addr,
Handler: handler,
// ReadTimeout: server.HTTPTimeout,
// WriteTimeout: server.HTTPTimeout,
}
server.SetKeepAlivesEnabled(true)
err = server.Serve(tls.NewListener(this.listener, config))
if err != nil {
log.Error("http.Server.Serve(addr{%s}) = err{%#v}", this.addr, err)
panic(err)
}
}()
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment