Commit a153e2e0 authored by aliiohs's avatar aliiohs

add tls config builder interface

parent 081c4ef0
......@@ -152,10 +152,11 @@ func (c *client) dialTCP() Session {
if c.IsClosed() {
return nil
}
sslConfig := c.loadSslConfig()
if c.sslEnabled && sslConfig != nil {
d := &net.Dialer{Timeout: connectTimeout}
conn, err = tls.DialWithDialer(d, "tcp", c.addr, sslConfig)
if c.sslEnabled {
if sslConfig, err := c.tlsConfigBuilder.BuildTlsConfig(); err == nil && sslConfig != nil {
d := &net.Dialer{Timeout: connectTimeout}
conn, err = tls.DialWithDialer(d, "tcp", c.addr, sslConfig)
}
} else {
conn, err = net.DialTimeout("tcp", c.addr, connectTimeout)
}
......@@ -475,71 +476,3 @@ func (c *client) Close() {
c.stop()
c.wg.Wait()
}
func (c *client) loadSslConfig() *tls.Config {
/* var (
err error
root *x509.Certificate
roots []*x509.Certificate
certPool *x509.CertPool
config *tls.Config
)
config = &tls.Config{
InsecureSkipVerify: true,
}
if c.clientTrustCertCollectionPath != "" {
certPEMBlock, err := ioutil.ReadFile(c.clientTrustCertCollectionPath)
if err != nil {
panic(fmt.Sprintf("ioutil.ReadFile(cert:%s) = error:%+v", c.clientTrustCertCollectionPath, perrors.WithStack(err)))
}
var cert tls.Certificate
for {
var certDERBlock *pem.Block
certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
if certDERBlock == nil {
break
}
if certDERBlock.Type == "CERTIFICATE" {
cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
}
}
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0] = cert
}
certPool = x509.NewCertPool()
for _, ce := range config.Certificates {
roots, err = x509.ParseCertificates(ce.Certificate[len(ce.Certificate)-1])
if err != nil {
panic(fmt.Sprintf("error parsing server's root cert: %+v\n", perrors.WithStack(err)))
}
for _, root = range roots {
certPool.AddCert(root)
}
}
config.InsecureSkipVerify = true
config.RootCAs = certPool
return config
*/
cert, err := tls.LoadX509KeyPair(c.clientTrustCertCollectionPath, c.clientPrivateKeyPath)
if err != nil {
s := fmt.Sprintf("Unable to load X509 Key Pair %v", err)
panic(s)
}
certBytes, err := ioutil.ReadFile(c.clientTrustCertCollectionPath)
if err != nil {
panic("Unable to read cert.pem")
}
clientCertPool := x509.NewCertPool()
ok := clientCertPool.AppendCertsFromPEM(certBytes)
if !ok {
panic("failed to parse root certificate")
}
return &tls.Config{
RootCAs: clientCertPool,
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: true,
}
}
......@@ -323,10 +323,10 @@ func (t *gettyTCPConn) close(waitSec int) {
}
}
if conn, ok := t.conn.(*net.TCPConn); ok {
conn.SetLinger(waitSec)
conn.Close()
_ = conn.SetLinger(waitSec)
_ = conn.Close()
} else {
t.conn.(*tls.Conn).Close()
_ = t.conn.(*tls.Conn).Close()
}
t.conn = nil
......
......@@ -61,12 +61,14 @@ func main() {
gxsync.WithTaskPoolTaskPoolSize(*taskPoolSize),
)
}
config := &getty.ClientTlsConfigBuilder{
ClientTrustCertCollectionPath: `E:\Projects\openSource\dubbo-samples\java\dubbo-samples-ssl\dubbo-samples-ssl-consumer\src\main\resources\certs\ca.pem`,
ClientPrivateKeyPath: `E:\Projects\openSource\dubbo-samples\java\dubbo-samples-ssl\dubbo-samples-ssl-provider\src\main\resources\certs\ca.key`,
}
client := getty.NewTCPClient(
getty.WithServerAddress(*ip+":8090"),
getty.WithClientSslEnabled(true),
getty.WithClientTrustCertCollectionPath(`E:\Projects\openSource\dubbo-samples\java\dubbo-samples-ssl\dubbo-samples-ssl-consumer\src\main\resources\certs\ca.pem`),
getty.WithClientPrivateKeyPath(`E:\Projects\openSource\dubbo-samples\java\dubbo-samples-ssl\dubbo-samples-ssl-provider\src\main\resources\certs\ca.key`),
getty.WithClientTlsConfigBuilder(config),
getty.WithConnectionNumber(*connections),
)
......
......@@ -50,12 +50,15 @@ func main() {
util.SetLimit()
util.Profiling(*pprofPort)
c := &getty.ServerTlsConfigBuilder{
ServerKeyCertChainPath: `E:\Projects\openSource\dubbo-samples\java\dubbo-samples-ssl\dubbo-samples-ssl-provider\src\main\resources\certs\server0.pem`,
ServerPrivateKeyPath: `E:\Projects\openSource\dubbo-samples\java\dubbo-samples-ssl\dubbo-samples-ssl-provider\src\main\resources\certs\server0.key`,
ServerTrustCertCollectionPath: `E:\Projects\openSource\dubbo-samples\java\dubbo-samples-ssl\dubbo-samples-ssl-consumer\src\main\resources\certs\ca.pem`,
}
options := []getty.ServerOption{getty.WithLocalAddress(":8090"),
getty.WithServerSslEnabled(true),
getty.WithServerKeyCertChainPath(`E:\Projects\openSource\dubbo-samples\java\dubbo-samples-ssl\dubbo-samples-ssl-provider\src\main\resources\certs\server0.pem`),
getty.WithServerPrivateKeyPath(`E:\Projects\openSource\dubbo-samples\java\dubbo-samples-ssl\dubbo-samples-ssl-provider\src\main\resources\certs\server0.key`),
getty.WithServerTrustCertCollectionPath(`E:\Projects\openSource\dubbo-samples\java\dubbo-samples-ssl\dubbo-samples-ssl-consumer\src\main\resources\certs\ca.pem`),
getty.WithServerTlsConfigBuilder(c),
}
if *taskPoolMode {
......
......@@ -26,12 +26,8 @@ type ServerOption func(*ServerOptions)
type ServerOptions struct {
addr string
//tls
sslEnabled bool
serverKeyCertChainPath string
serverPrivateKeyPath string
serverKeyPassword string
serverTrustCertCollectionPath string
sslEnabled bool
tlsConfigBuilder TlsConfigBuilder
// websocket
path string
cert string
......@@ -82,30 +78,9 @@ func WithServerSslEnabled(sslEnabled bool) ServerOption {
}
// @WithServerKeyCertChainPath sslConfig is tls config
func WithServerKeyCertChainPath(serverKeyCertChainPath string) ServerOption {
return func(o *ServerOptions) {
o.serverKeyCertChainPath = serverKeyCertChainPath
}
}
// @WithServerPrivateKeyPath sslConfig is tls config
func WithServerPrivateKeyPath(serverPrivateKeyPath string) ServerOption {
return func(o *ServerOptions) {
o.serverPrivateKeyPath = serverPrivateKeyPath
}
}
// @WithServerKeyPassword sslConfig is tls config
func WithServerKeyPassword(serverKeyPassword string) ServerOption {
return func(o *ServerOptions) {
o.serverKeyPassword = serverKeyPassword
}
}
// @WithServerTrustCertCollectionPath sslConfig is tls config
func WithServerTrustCertCollectionPath(serverTrustCertCollectionPath string) ServerOption {
func WithServerTlsConfigBuilder(tlsConfigBuilder TlsConfigBuilder) ServerOption {
return func(o *ServerOptions) {
o.serverTrustCertCollectionPath = serverTrustCertCollectionPath
o.tlsConfigBuilder = tlsConfigBuilder
}
}
......@@ -121,11 +96,8 @@ type ClientOptions struct {
reconnectInterval int // reConnect Interval
//tls
sslEnabled bool
clientKeyCertChainPath string
clientPrivateKeyPath string
clientKeyPassword string
clientTrustCertCollectionPath string
sslEnabled bool
tlsConfigBuilder TlsConfigBuilder
// the cert file of wss server which may contain server domain, server ip, the starting effective date, effective
// duration, the hash alg, the len of the private key.
......@@ -173,29 +145,8 @@ func WithClientSslEnabled(sslEnabled bool) ClientOption {
}
// @WithClientKeyCertChainPath sslConfig is tls config
func WithClientKeyCertChainPath(clientKeyCertChainPath string) ClientOption {
return func(o *ClientOptions) {
o.clientKeyCertChainPath = clientKeyCertChainPath
}
}
// @WithClientPrivateKeyPath sslConfig is tls config
func WithClientPrivateKeyPath(clientPrivateKeyPath string) ClientOption {
return func(o *ClientOptions) {
o.clientPrivateKeyPath = clientPrivateKeyPath
}
}
// @WithClientKeyPassword sslConfig is tls config
func WithClientKeyPassword(clientKeyPassword string) ClientOption {
return func(o *ClientOptions) {
o.clientKeyPassword = clientKeyPassword
}
}
// @WithClientTrustCertCollectionPath sslConfig is tls config
func WithClientTrustCertCollectionPath(clientTrustCertCollectionPath string) ClientOption {
func WithClientTlsConfigBuilder(tlsConfigBuilder TlsConfigBuilder) ClientOption {
return func(o *ClientOptions) {
o.clientTrustCertCollectionPath = clientTrustCertCollectionPath
o.tlsConfigBuilder = tlsConfigBuilder
}
}
......@@ -175,7 +175,7 @@ func (s *server) listenTCP() error {
}
} else {
if s.sslEnabled {
if sslConfig := s.loadSslConfig(); sslConfig != nil {
if sslConfig, err := s.tlsConfigBuilder.BuildTlsConfig(); err == nil && sslConfig != nil {
streamListener, err = tls.Listen("tcp", s.addr, sslConfig)
}
} else {
......@@ -492,38 +492,3 @@ func (s *server) Close() {
s.stop()
s.wg.Wait()
}
func (s *server) loadSslConfig() *tls.Config {
var (
err error
certPem []byte
certificate tls.Certificate
certPool *x509.CertPool
config *tls.Config
)
if certificate, err = tls.LoadX509KeyPair(s.serverKeyCertChainPath, s.serverPrivateKeyPath); err != nil {
panic(fmt.Sprintf("tls.LoadX509KeyPair(cert{%s}, privateKey{%s}) = err:%+v",
s.serverKeyCertChainPath, s.serverPrivateKeyPath, perrors.WithStack(err)))
}
config = &tls.Config{
InsecureSkipVerify: true, // do not verify peer cert
ClientAuth: tls.RequireAnyClientCert,
Certificates: []tls.Certificate{certificate},
}
if s.serverTrustCertCollectionPath != "" {
certPem, err = ioutil.ReadFile(s.serverTrustCertCollectionPath)
if err != nil {
panic(fmt.Errorf("ioutil.ReadFile(certFile{%s}) = err:%+v", s.serverTrustCertCollectionPath, perrors.WithStack(err)))
}
certPool = x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM(certPem); !ok {
panic("failed to parse root certificate file")
}
config.ClientCAs = certPool
config.ClientAuth = tls.RequireAnyClientCert
config.InsecureSkipVerify = false
}
return config
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package getty
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
)
import (
perrors "github.com/pkg/errors"
)
type TlsConfigBuilder interface {
BuildTlsConfig() (*tls.Config, error)
}
type ServerTlsConfigBuilder struct {
ServerKeyCertChainPath string
ServerPrivateKeyPath string
ServerKeyPassword string
ServerTrustCertCollectionPath string
}
func (s *ServerTlsConfigBuilder) BuildTlsConfig() (*tls.Config, error) {
var (
err error
certPem []byte
certificate tls.Certificate
certPool *x509.CertPool
config *tls.Config
)
if certificate, err = tls.LoadX509KeyPair(s.ServerKeyCertChainPath, s.ServerPrivateKeyPath); err != nil {
log.Error(fmt.Sprintf("tls.LoadX509KeyPair(cert{%s}, privateKey{%s}) = err:%+v",
s.ServerKeyCertChainPath, s.ServerPrivateKeyPath, perrors.WithStack(err)))
return nil, err
}
config = &tls.Config{
InsecureSkipVerify: true, // do not verify peer cert
ClientAuth: tls.RequireAnyClientCert,
Certificates: []tls.Certificate{certificate},
}
if s.ServerTrustCertCollectionPath != "" {
certPem, err = ioutil.ReadFile(s.ServerTrustCertCollectionPath)
if err != nil {
log.Error(fmt.Errorf("ioutil.ReadFile(certFile{%s}) = err:%+v", s.ServerTrustCertCollectionPath, perrors.WithStack(err)))
return nil, err
}
certPool = x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM(certPem); !ok {
log.Error("failed to parse root certificate file")
return nil, err
}
config.ClientCAs = certPool
config.ClientAuth = tls.RequireAnyClientCert
config.InsecureSkipVerify = false
}
return config, nil
}
type ClientTlsConfigBuilder struct {
ClientKeyCertChainPath string
ClientPrivateKeyPath string
ClientKeyPassword string
ClientTrustCertCollectionPath string
}
func (c *ClientTlsConfigBuilder) BuildTlsConfig() (*tls.Config, error) {
cert, err := tls.LoadX509KeyPair(c.ClientTrustCertCollectionPath, c.ClientPrivateKeyPath)
if err != nil {
log.Error(fmt.Sprintf("Unable to load X509 Key Pair %v", err))
return nil, err
}
certBytes, err := ioutil.ReadFile(c.ClientTrustCertCollectionPath)
if err != nil {
log.Error(fmt.Sprintf("Unable to read pem file: %s", c.ClientTrustCertCollectionPath))
return nil, err
}
clientCertPool := x509.NewCertPool()
ok := clientCertPool.AppendCertsFromPEM(certBytes)
if !ok {
log.Error("failed to parse root certificate")
return nil, err
}
return &tls.Config{
RootCAs: clientCertPool,
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: true,
}, nil
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment