Commit 942eb493 authored by Tsaiilin

update

parent 25464140
@@ -19,7 +19,6 @@ type Client struct {
     session       getty.Session
     transferStore sync.Map
     dnsCache      *freecache.Cache
-    close         chan struct{}
 }
 
 func NewClientWithConf(cfg *model.Configure, host string, port int) *Client {
@@ -49,3 +48,12 @@ func NewCli(cfg *model.Configure, host string, port int) *Client {
 func (client *Client) StartUp() {
     client.connect()
 }
+
+func (client *Client) Close() {
+    client.transferStore.Range(func(key, value interface{}) bool {
+        t := value.(*Transfer)
+        t.Close()
+        return true
+    })
+    client.transferStore = sync.Map{}
+}
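For context: transferStore holds the live *Transfer objects for this client, and the new Client.Close drains it so a reconnected session starts with a clean slate. A minimal, self-contained sketch of that pattern, assuming the map is keyed by the transfer serial number (the stub Transfer type below stands in for the real one):

package main

import (
    "fmt"
    "sync"
)

// Stand-in for the real Transfer; the real Close is idempotent (guarded by sync.Once).
type Transfer struct{ serialNumber int64 }

func (t *Transfer) Close() { fmt.Printf("transfer %d closed\n", t.serialNumber) }

func main() {
    var transferStore sync.Map
    transferStore.Store(int64(7), &Transfer{serialNumber: 7})
    transferStore.Store(int64(8), &Transfer{serialNumber: 8})

    // Same shape as Client.Close above: close every tracked transfer, then empty the store.
    // (The commit resets the field to sync.Map{}; deleting entries in place avoids copying the map value.)
    transferStore.Range(func(key, value interface{}) bool {
        value.(*Transfer).Close()
        transferStore.Delete(key)
        return true
    })
}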
@@ -29,21 +29,21 @@ func (c *ClusterClient) Start() {
             timer.Reset(5 * time.Minute)
         }
     })
-    if global.Config.Redial.Valid() {
-        _ = taskPool.Submit(func() {
-            // add jitter so that VPS instances do not all redial at the same time
-            duration := c.randomDuration()
-            log.Run().Infof("Redial interval %+v", duration)
-            var timer = time.NewTimer(duration)
-            for {
-                <-timer.C
-                c.StartRedial("cron", true)
-                duration = c.randomDuration()
-                log.Run().Infof("Redial interval %+v", duration)
-                timer.Reset(duration)
-            }
-        })
-    }
+    //if global.Config.Redial.Valid() {
+    //    _ = taskPool.Submit(func() {
+    //        // add jitter so that VPS instances do not all redial at the same time
+    //        duration := c.randomDuration()
+    //        log.Run().Infof("Redial interval %+v", duration)
+    //        var timer = time.NewTimer(duration)
+    //        for {
+    //            <-timer.C
+    //            c.StartRedial("cron", true)
+    //            duration = c.randomDuration()
+    //            log.Run().Infof("Redial interval %+v", duration)
+    //            timer.Reset(duration)
+    //        }
+    //    })
+    //}
 }
 
 func (c *ClusterClient) connectNatServers() {
@@ -88,7 +88,6 @@ func (c *ClusterClient) connectNatServers() {
             log.Error().Error("[connectNatServers] client already remove")
         }
         needCloseClient := load.(*Client)
-        needCloseClient.CloseAll()
         needCloseClient.natTunnel.Close()
         return true
     })
@@ -118,7 +117,6 @@ func (c *ClusterClient) StartRedial(tag string, replay bool) {
     log.Run().Info("[Redial %s ] start close local session", tag)
     c.clients.Range(func(host, c interface{}) bool {
         client, _ := c.(*Client)
-        client.close = make(chan struct{})
         client.natTunnel.Close()
         return true
     })
@@ -126,11 +124,7 @@ func (c *ClusterClient) StartRedial(tag string, replay bool) {
     c.clients.Range(func(host, c interface{}) bool {
         client, _ := c.(*Client)
         _ = taskPool.Submit(func() {
-            select {
-            case <-client.close:
-                client.connect()
-                client.close = nil
-            }
+            client.connect()
         })
         return true
     })
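The block commented out in Start() above used to schedule periodic redials at a randomized interval so that a fleet of VPS instances would not drop their links at the same moment. A minimal sketch of that jitter pattern, assuming an illustrative base interval and jitter window rather than the real randomDuration implementation:

package main

import (
    "math/rand"
    "time"
)

// scheduleRedial mirrors the disabled loop above: sleep a jittered interval,
// trigger a redial, then re-arm the timer. The 5-minute base and jitter window
// are illustrative assumptions, not values taken from the project config.
func scheduleRedial(redial func(tag string, replay bool)) {
    const (
        base   = 5 * time.Minute
        jitter = 5 * time.Minute
    )
    next := func() time.Duration { return base + time.Duration(rand.Int63n(int64(jitter))) }

    timer := time.NewTimer(next())
    for {
        <-timer.C
        redial("cron", true)
        timer.Reset(next())
    }
}

func main() {
    scheduleRedial(func(tag string, replay bool) { /* e.g. c.StartRedial(tag, replay) */ })
}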
@@ -31,20 +31,12 @@ func (m *MajoraEventListener) OnOpen(session getty.Session) error {
 }
 
 func (m *MajoraEventListener) OnClose(session getty.Session) {
-    _ = taskPool.Submit(func() {
-        log.Run().Infof("OnClose-> session closed %v", session.IsClosed())
-        m.client.CloseAll()
-        if m.client.close != nil {
-            m.client.close <- struct{}{}
-        }
-    })
+    m.client.Close()
+    log.Run().Infof("OnClose-> session closed %v", session.IsClosed())
 }
 
 func (m *MajoraEventListener) OnError(session getty.Session, err error) {
-    _ = taskPool.Submit(func() {
-        log.Error().Errorf("OnError %s", err.Error())
-        m.client.CloseAll()
-    })
+    log.Error().Errorf("OnError %s", err.Error())
 }
 
 func (m *MajoraEventListener) OnCron(session getty.Session) {
@@ -3,6 +3,7 @@ package client
 import (
     "errors"
     "fmt"
+    "github.com/adamweixuan/getty"
     "net"
     "sync"
     "time"
@@ -12,87 +13,150 @@ import (
     "virjar.com/majora-go/trace"
 )
 
+type TransferListener interface {
+    OnUpStreamConnectSuccess(t *Transfer)
+    OnUpStreamConnectFailed(t *Transfer, err error)
+    OnUpStreamWriteError(t *Transfer, err error)
+    OnUpStreamReadError(t *Transfer, err error)
+    OnDownStreamWriteError(t *Transfer, err error)
+}
+
 type Transfer struct {
-    serialNumber             int64
-    client                   *Client
-    upstreamConn             *net.TCPConn
-    recorder                 trace.Recorder
-    transferChan             chan *protocol.MajoraPacket
-    transferToUpstreamFunc   func(t *Transfer, p *protocol.MajoraPacket)
-    transferToDownstreamFunc func(t *Transfer, data []byte, err error)
-    once                     sync.Once
-    cancel                   chan struct{}
+    serialNumber int64
+    target       string
+    upstreamConn *net.TCPConn
+    session      getty.Session
+    recorder     trace.Recorder
+    transferChan chan []byte
+    once         sync.Once
+    listener     TransferListener
+    cancel       chan struct{}
 }
 
-func NewTransfer(serialNumber int64, client *Client, conn *net.TCPConn, recorder trace.Recorder) *Transfer {
+func NewTransfer(serialNumber int64, target string, session getty.Session, listener TransferListener, recorder trace.Recorder) *Transfer {
     return &Transfer{
         serialNumber: serialNumber,
-        client:       client,
-        upstreamConn: conn,
-        recorder:     recorder,
-        transferChan: make(chan *protocol.MajoraPacket, 10),
+        target:       target,
+        session:      session,
+        recorder:     recorder,
+        listener:     listener,
+        transferChan: make(chan []byte, 10),
         once:         sync.Once{},
         cancel:       make(chan struct{}, 0),
     }
 }
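NewTransfer now takes a TransferListener instead of holding a *Client, which decouples the upstream/downstream error handling from the client type. The diff does not show who implements the interface (presumably the Client); purely as an illustration, a logging listener in the same package could look like this:

// Sketch only: an illustrative TransferListener that just logs each callback.
// It is not part of this commit; the real listener wired in by NewTransfer is
// whatever the client package passes in.
type loggingTransferListener struct{}

func (loggingTransferListener) OnUpStreamConnectSuccess(t *Transfer) {
    log.Run().Debugf("upstream connected (sn:%d)", t.serialNumber)
}

func (loggingTransferListener) OnUpStreamConnectFailed(t *Transfer, err error) {
    log.Error().Errorf("upstream connect failed (sn:%d): %v", t.serialNumber, err)
}

func (loggingTransferListener) OnUpStreamWriteError(t *Transfer, err error) {
    log.Error().Errorf("upstream write failed (sn:%d): %v", t.serialNumber, err)
}

func (loggingTransferListener) OnUpStreamReadError(t *Transfer, err error) {
    log.Error().Errorf("upstream read failed (sn:%d): %v", t.serialNumber, err)
}

func (loggingTransferListener) OnDownStreamWriteError(t *Transfer, err error) {
    log.Error().Errorf("downstream write failed (sn:%d): %v", t.serialNumber, err)
}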
-func (t *Transfer) SetTransferToUpstreamFunc(f func(t *Transfer, p *protocol.MajoraPacket)) {
-    t.transferToUpstreamFunc = f
-}
+func (t *Transfer) transferToUpstream(data []byte) {
+    cnt, err := t.upstreamConn.Write(data)
+    if err != nil {
+        log.Error().Errorf("[handleTransfer] %d->write to upstream fail for %s", t.serialNumber, err)
+        traceMessage := fmt.Sprintf("Write to upstream failed (%d)", t.serialNumber)
+        t.recorder.RecordErrorEvent(trace.TransferEvent, traceMessage, err)
+        t.listener.OnUpStreamWriteError(t, err)
+        return
+    }
+    if cnt != len(data) {
+        log.Error().Errorf("[handleTransfer] %d-> write not all data for expect->%d/%d",
+            t.serialNumber, len(data), cnt)
+        traceMessage := fmt.Sprintf("Write not all data for expect -> %d/%d (sn:%d)", len(data), cnt, t.serialNumber)
+        t.recorder.RecordErrorEvent(trace.TransferEvent, traceMessage, nil)
+        t.listener.OnUpStreamWriteError(t, err)
+        return
+    }
+    log.Run().Debugf("[handleTransfer] %d-> success dataLen: %d", t.serialNumber, len(data))
+    traceMessage := fmt.Sprintf("transfer data success (%d)", t.serialNumber)
+    t.recorder.RecordEvent(trace.TransferEvent, traceMessage)
+}
 
-func (t *Transfer) SetTransferToDownstreamFunc(f func(t *Transfer, data []byte, err error)) {
-    t.transferToDownstreamFunc = f
-}
+func (t *Transfer) transferToDownStream() {
+    _ = taskPool.Submit(func() {
+        traceRecorder := t.recorder
+        traceRecorder.RecordEvent(trace.UpStreamEvent, fmt.Sprintf("Ready read from upstream (sn:%d)", t.serialNumber))
+        log.Run().Debugf("[handleUpStream] %d-> handleUpStream start...", t.serialNumber)
+        for {
+            buf := make([]byte, common.BufSize)
+            cnt, err := t.upstreamConn.Read(buf)
+            if t.session.IsClosed() {
+                t.recorder.RecordErrorEvent(trace.UpStreamEvent, fmt.Sprintf("DownStream closed(sn:%d)", t.serialNumber),
+                    errors.New("closed"))
+                return
+            }
+            if err != nil {
+                log.Run().Debugf("[handleUpStream] %d->read with error:%+v,l:%s->r:%s",
+                    t.serialNumber, err, t.upstreamConn.LocalAddr(), t.upstreamConn.RemoteAddr())
+                recorderMessage := fmt.Sprintf("Read with l:%s->r:%s (sn:%d) ",
+                    t.upstreamConn.LocalAddr(), t.upstreamConn.RemoteAddr(), t.serialNumber)
+                t.recorder.RecordErrorEvent(trace.UpStreamEvent, recorderMessage, err)
+                t.listener.OnUpStreamReadError(t, err)
+                break
+            } else {
+                t.recorder.RecordEvent(trace.UpStreamEvent, fmt.Sprintf("read count: %d (sn:%d)",
+                    cnt, t.serialNumber))
+                t.recorder.RecordEvent(trace.UpStreamEvent, fmt.Sprintf("Start write to natServer (sn:%d)", t.serialNumber))
+                pack := protocol.TypeTransfer.CreatePacket()
+                pack.Data = buf[0:cnt]
+                pack.SerialNumber = t.serialNumber
+                if _, _, err := t.session.WritePkg(pack, 0); err != nil {
+                    log.Error().Errorf("[handleUpStream] %d-> write to server fail %+v", t.serialNumber, err.Error())
+                    t.recorder.RecordErrorEvent(trace.UpStreamEvent,
+                        fmt.Sprintf("Write to natServer failed (sn:%d)", t.serialNumber), err)
+                    t.listener.OnDownStreamWriteError(t, err)
+                } else {
+                    log.Run().Debugf("[handleUpStream] %d->success dataLen:%d ", t.serialNumber, len(pack.Data))
+                    t.recorder.RecordEvent(trace.UpStreamEvent,
+                        fmt.Sprintf("Write to natServer success(sn:%d)", pack.SerialNumber))
+                }
+            }
+        }
+    })
+}
 
-func (t *Transfer) TransferToUpstream(p *protocol.MajoraPacket) {
-    t.transferChan <- p
+func (t *Transfer) TransferToUpstream(data []byte) {
+    t.transferChan <- data
 }
 
 func (t *Transfer) Start() {
-    if t.transferToUpstreamFunc == nil {
-        panic(errors.New("transferToUpstreamFunc is nil"))
-    }
-    if t.transferToDownstreamFunc == nil {
-        panic(errors.New("transferToDownstreamFunc is nil"))
-    }
+    dialer := net.Dialer{
+        Timeout: common.UpstreamTimeout,
+    }
+    conn, err := dialer.Dial(common.TCP, t.target)
+    if err != nil {
+        log.Error().Errorf("[handleConnect] %d->connect to %s->%s", t.serialNumber, t.target, err.Error())
+        t.recorder.RecordErrorEvent(trace.ConnectEvent,
+            fmt.Sprintf("Connect to %s failed (sn:%d)", t.target, t.serialNumber), err)
+        t.listener.OnUpStreamConnectFailed(t, err)
+        return
+    }
+    t.upstreamConn = conn.(*net.TCPConn)
+    _ = t.upstreamConn.SetDeadline(time.Now().Add(45 * time.Second))
+    _ = t.upstreamConn.SetNoDelay(true)
+    _ = t.upstreamConn.SetKeepAlive(true)
+    t.listener.OnUpStreamConnectSuccess(t)
     _ = taskPool.Submit(func() {
         for {
             select {
-            case p := <-t.transferChan:
-                t.transferToUpstreamFunc(t, p)
+            case data := <-t.transferChan:
+                t.transferToUpstream(data)
             case <-t.cancel:
                 return
             }
         }
     })
-    _ = taskPool.Submit(func() {
-        traceRecorder := t.recorder
-        traceRecorder.RecordEvent(trace.UpStreamEvent, fmt.Sprintf("Ready read from upstream (sn:%d)", t.serialNumber))
-        log.Run().Debugf("[handleUpStream] %d-> handleUpStream start...", t.serialNumber)
-        for {
-            buf := make([]byte, common.BufSize)
-            cnt, err := t.upstreamConn.Read(buf)
-            t.transferToDownstreamFunc(t, buf[0:cnt], err)
-            if err != nil {
-                break
-            }
-        }
-    })
+    _ = taskPool.Submit(t.transferToDownStream)
 }
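Taken together, the refactored lifecycle is: build a Transfer with the dial target, the NAT session and a listener; Start it to dial the upstream and spawn the writer and reader goroutines; feed it payloads via TransferToUpstream; and Close it when the tunnel side is done. A sketch of that flow, with a made-up serial number and target address:

// Sketch only: driving a Transfer after this refactor. The session, listener
// and recorder come from the surrounding client code; the serial number and
// target address here are made-up example values.
func runTransferExample(session getty.Session, listener TransferListener, recorder trace.Recorder) {
    t := NewTransfer(42, "127.0.0.1:8080", session, listener, recorder)
    t.Start()                             // dial upstream, start the pump goroutines
    t.TransferToUpstream([]byte("hello")) // queue a payload for the upstream connection
    // ... later, once the NAT session or the upstream is finished:
    t.Close()
}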
 func (t *Transfer) Close() {
     t.once.Do(func() {
-        readDeadLine := time.Now().Add(3 * time.Millisecond)
-        t.recorder.RecordEvent(trace.DisconnectEvent, fmt.Sprintf("Set upstream read deadline:%s (sn:%d)",
-            readDeadLine.Format("2006-01-02 15:04:05.000000"), t.serialNumber))
-        err := t.upstreamConn.SetReadDeadline(readDeadLine)
+        err := t.upstreamConn.Close()
         if err != nil {
             t.recorder.RecordErrorEvent(trace.DisconnectEvent,
                 fmt.Sprintf("Set upstream read deadline failed (sn:%d)", t.serialNumber), err)
-            _ = t.upstreamConn.Close()
         }
         close(t.cancel)
     })
-tunnel_addr: 127.0.0.1:5879
+tunnel_addr: majora-vps.virjar.com:5879
 dns_server: 114.114.114.114:53
 #daemon: true
-log_level: info
+log_level: debug
 log_path: ./majora-log/
 reconn_intervalz: 5s
-net_check_interval: 5s
+#net_check_interval: 5s
 dns_cache_duration: 10m
 net_check_url: https://www.baidu.com
@@ -4,15 +4,15 @@ log_level: info
 log_path: ./majora-log/
 daemon: true
 reconn_interval: 5s
-net_check_interval: 5s
+#net_check_interval: 5s
 net_check_url: https://www.baidu.com
 dns_cache_duration: 10m
-redial:
-  command: /bin/bash
-  exec_path: /root/ppp_redial.sh
-  redial_duration: 5m
-  wait_time: 15s
+#redial:
+#  command: /bin/bash
+#  exec_path: /root/ppp_redial.sh
+#  redial_duration: 5m
+#  wait_time: 15s
 extra:
   account: superman
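For reference, the keys toggled in these two YAML files map onto the client configuration. A rough sketch of how such a file could be decoded, assuming gopkg.in/yaml.v3 and field names guessed from the keys above (the real model.Configure in majora-go may differ; durations are kept as strings and parsed explicitly because yaml.v3 has no built-in time.Duration handling):

package main

import (
    "fmt"
    "os"
    "time"

    "gopkg.in/yaml.v3"
)

// Shapes guessed from the YAML keys above; illustrative only.
type RedialConf struct {
    Command        string `yaml:"command"`
    ExecPath       string `yaml:"exec_path"`
    RedialDuration string `yaml:"redial_duration"` // e.g. "5m"
    WaitTime       string `yaml:"wait_time"`       // e.g. "15s"
}

type Configure struct {
    TunnelAddr       string            `yaml:"tunnel_addr"`
    DNSServer        string            `yaml:"dns_server"`
    Daemon           bool              `yaml:"daemon"`
    LogLevel         string            `yaml:"log_level"`
    LogPath          string            `yaml:"log_path"`
    ReconnInterval   string            `yaml:"reconn_interval"`
    NetCheckInterval string            `yaml:"net_check_interval"`
    NetCheckURL      string            `yaml:"net_check_url"`
    DNSCacheDuration string            `yaml:"dns_cache_duration"`
    Redial           *RedialConf       `yaml:"redial"`
    Extra            map[string]string `yaml:"extra"`
}

func main() {
    raw, err := os.ReadFile("majora.yaml") // path is illustrative
    if err != nil {
        panic(err)
    }
    var cfg Configure
    if err := yaml.Unmarshal(raw, &cfg); err != nil {
        panic(err)
    }
    reconn, err := time.ParseDuration(cfg.ReconnInterval)
    if err != nil {
        panic(err)
    }
    fmt.Printf("tunnel=%s reconnect every %s\n", cfg.TunnelAddr, reconn)
}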