Commit ec9ec7e2 authored by Tsaiilin's avatar Tsaiilin

修改 transfer 过程逻辑,修复在极限情况下 transfer过程不安全问题

parent bf888d5e
......@@ -13,7 +13,6 @@ import (
"virjar.com/majora-go/global"
"virjar.com/majora-go/log"
"virjar.com/majora-go/protocol"
"virjar.com/majora-go/safe"
"virjar.com/majora-go/trace"
)
......@@ -88,9 +87,9 @@ func (client *Client) handleConnect(packet *protocol.MajoraPacket, session getty
target = hosts[0]
} else {
target = string(ip)
recorder.RecordEvent(trace.DnsResolveEvent, fmt.Sprintf("Dns cache hit %s -> %s", hostPort[0], target))
}
recorder.RecordEvent(trace.DnsResolveEvent, fmt.Sprintf("Dns cache hit %s -> %s", hostPort[0], target))
conn, err := dialer.Dial(common.TCP, fmt.Sprintf("%s:%s", target, hostPort[1]))
if err != nil {
log.Error().Errorf("[handleConnect] %d->connect to %s->%s", packet.SerialNumber, packet.Extra, err.Error())
......@@ -103,9 +102,9 @@ func (client *Client) handleConnect(packet *protocol.MajoraPacket, session getty
_ = tcpConn.SetNoDelay(true)
_ = tcpConn.SetKeepAlive(true)
t := NewTransfer(packet.SerialNumber, tcpConn, recorder)
t.SetTransferFunc(client.transfer)
t.Start()
t := NewTransfer(packet.SerialNumber, client, tcpConn, recorder)
t.SetTransferToUpstreamFunc(client.transferToUpstream)
t.SetTransferToDownstreamFunc(client.transferToDownstream)
client.AddTransfer(packet.SerialNumber, t, packet.Extra)
recorder.RecordEvent(trace.ConnectEvent, fmt.Sprintf("Connect to %s success, local: %s -> remote:%s (aaasn:%d)",
packet.Extra, tcpConn.LocalAddr(), tcpConn.RemoteAddr(), packet.SerialNumber))
......@@ -129,16 +128,14 @@ func (client *Client) handleConnect(packet *protocol.MajoraPacket, session getty
client.closeVirtualConnection(session, packet.SerialNumber)
return
} else {
safe.Go(func() {
client.handleUpStream(tcpConn, packet, session)
})
t.Start()
log.Run().Debugf("[handleConnect] %d->connect success to %s ", packet.SerialNumber, packet.Extra)
recorder.RecordEvent(trace.ConnectEvent, fmt.Sprintf("Replay natServer connect ready success (sn:%d)", packet.SerialNumber))
}
}
func decodeMap(data []byte) map[string]string {
result := make(map[string]string, 1)
result := make(map[string]string, 2)
var headerSize int8
err := binary.Read(bytes.NewBuffer(data[:1]), binary.BigEndian, &headerSize)
data = data[1:]
......@@ -176,50 +173,7 @@ func (client *Client) handleTransfer(packet *protocol.MajoraPacket, session gett
}
t.recorder.RecordEvent(trace.TransferEvent,
fmt.Sprintf("Receive transfer packet from natServer,start to be forward to target, len:%d (%d)", len(packet.Data), packet.SerialNumber))
t.Send(packet)
}
func (client *Client) handleUpStream(conn *net.TCPConn, packet *protocol.MajoraPacket, session getty.Session) {
traceRecorder := client.GetRecorder(packet.SerialNumber)
traceRecorder.RecordEvent(trace.UpStreamEvent, fmt.Sprintf("Ready read from upstream (sn:%d)", packet.SerialNumber))
log.Run().Debugf("[handleUpStream] %d-> handleUpStream start...", packet.SerialNumber)
for {
buf := make([]byte, common.BufSize)
cnt, err := conn.Read(buf)
if err != nil {
opErr, ok := err.(*net.OpError)
if ok && opErr.Err.Error() == "i/o timeout" {
recorderMessage := fmt.Sprintf("Upstream deadDeadline start close (sn:%d)", packet.SerialNumber)
traceRecorder.RecordEvent(trace.UpStreamEvent, recorderMessage)
} else {
log.Run().Debugf("[handleUpStream] %d->read with error:%+v,l:%s->r:%s",
packet.SerialNumber, err, conn.LocalAddr(), conn.RemoteAddr())
recorderMessage := fmt.Sprintf("Read with l:%s->r:%s (sn:%d) ",
conn.LocalAddr(), conn.RemoteAddr(), packet.SerialNumber)
traceRecorder.RecordErrorEvent(trace.UpStreamEvent, recorderMessage, err)
}
client.OnClose(session, conn, packet.SerialNumber)
break
}
traceRecorder.RecordEvent(trace.UpStreamEvent, fmt.Sprintf("read count: %d (sn:%d)",
cnt, packet.SerialNumber))
traceRecorder.RecordEvent(trace.UpStreamEvent, fmt.Sprintf("Start write to natServer (sn:%d)", packet.SerialNumber))
pack := protocol.TypeTransfer.CreatePacket()
pack.Data = buf[0:cnt]
pack.SerialNumber = packet.SerialNumber
if _, _, err := session.WritePkg(pack, 0); err != nil {
log.Error().Errorf("[handleUpStream] %d-> write to server fail %+v", packet.SerialNumber, err.Error())
traceRecorder.RecordErrorEvent(trace.UpStreamEvent,
fmt.Sprintf("Write to natServer failed (sn:%d)", packet.SerialNumber), err)
client.OnClose(session, conn, packet.SerialNumber)
break
} else {
log.Run().Debugf("[handleUpStream] %d->success dataLen:%d ", packet.SerialNumber, string(packet.Data))
traceRecorder.RecordEvent(trace.UpStreamEvent,
fmt.Sprintf("Write to natServer success(sn:%d)", packet.SerialNumber))
}
}
t.TransferToUpstream(packet)
}
func (client *Client) handleDisconnectMessage(session getty.Session, packet *protocol.MajoraPacket) {
......@@ -300,7 +254,7 @@ func (client *Client) closeVirtualConnection(session getty.Session, serialNumber
log.Run().Warnf("[closeVirtualConnection] ->%d error %s session closed %v allCnt %d sendCnt %d",
serialNumber, err.Error(), session.IsClosed(), allCnt, sendCnt)
traceRecorder.RecordErrorEvent(trace.DisconnectEvent,
fmt.Sprintf("Send disconnect to natServer failed closed:%v allCnt %d sendCnt %d (sn:%d)",
fmt.Sprintf("send disconnect to natServer failed closed:%v allCnt %d sendCnt %d (sn:%d)",
session.IsClosed(), allCnt, sendCnt, serialNumber), err)
session.Close()
}
......@@ -324,7 +278,7 @@ func (client *Client) CloseAll() {
client.transferStore = sync.Map{}
}
func (client *Client) transfer(t *Transfer, p *protocol.MajoraPacket) {
func (client *Client) transferToUpstream(t *Transfer, p *protocol.MajoraPacket) {
cnt, err := t.upstreamConn.Write(p.Data)
if err != nil {
log.Error().Errorf("[handleTransfer] %d->write to upstream fail for %s", p.SerialNumber, err)
......@@ -346,3 +300,38 @@ func (client *Client) transfer(t *Transfer, p *protocol.MajoraPacket) {
traceMessage := fmt.Sprintf("transfer data success (%d)", p.SerialNumber)
t.recorder.RecordEvent(trace.TransferEvent, traceMessage)
}
func (client *Client) transferToDownstream(t *Transfer, data []byte, err error) {
if err != nil {
opErr, ok := err.(*net.OpError)
if ok && opErr.Err.Error() == "i/o timeout" {
recorderMessage := fmt.Sprintf("Upstream deadDeadline start close (sn:%d)", t.serialNumber)
t.recorder.RecordEvent(trace.UpStreamEvent, recorderMessage)
} else {
log.Run().Debugf("[handleUpStream] %d->read with error:%+v,l:%s->r:%s",
t.serialNumber, err, t.upstreamConn.LocalAddr(), t.upstreamConn.RemoteAddr())
recorderMessage := fmt.Sprintf("Read with l:%s->r:%s (sn:%d) ",
t.upstreamConn.LocalAddr(), t.upstreamConn.RemoteAddr(), t.serialNumber)
t.recorder.RecordErrorEvent(trace.UpStreamEvent, recorderMessage, err)
}
t.client.OnClose(t.client.session, t.upstreamConn, t.serialNumber)
} else {
t.recorder.RecordEvent(trace.UpStreamEvent, fmt.Sprintf("read count: %d (sn:%d)",
len(data), t.serialNumber))
t.recorder.RecordEvent(trace.UpStreamEvent, fmt.Sprintf("Start write to natServer (sn:%d)", t.serialNumber))
pack := protocol.TypeTransfer.CreatePacket()
pack.Data = data
pack.SerialNumber = t.serialNumber
if _, _, err := t.client.session.WritePkg(pack, 0); err != nil {
log.Error().Errorf("[handleUpStream] %d-> write to server fail %+v", t.client, err.Error())
t.recorder.RecordErrorEvent(trace.UpStreamEvent,
fmt.Sprintf("Write to natServer failed (sn:%d)", t.serialNumber), err)
t.client.OnClose(t.client.session, t.upstreamConn, t.serialNumber)
} else {
log.Run().Debugf("[handleUpStream] %d->success dataLen:%d ", t.serialNumber, len(pack.Data))
t.recorder.RecordEvent(trace.UpStreamEvent,
fmt.Sprintf("Write to natServer success(sn:%d)", pack.SerialNumber))
}
}
}
......@@ -75,7 +75,7 @@ func (client *Client) Redial(tag string) {
if !client.config.Redial.Valid() {
return
}
log.Run().Infof("[Redial %s] Send offline message", tag)
log.Run().Infof("[Redial %s] seed offline message", tag)
if _, _, err := client.session.WritePkg(OfflinePacket, 0); err != nil {
log.Run().Errorf("[Redial %s] write offline to server error %s", tag, err.Error())
}
......
......@@ -6,24 +6,29 @@ import (
"net"
"sync"
"time"
"virjar.com/majora-go/common"
"virjar.com/majora-go/log"
"virjar.com/majora-go/protocol"
"virjar.com/majora-go/safe"
"virjar.com/majora-go/trace"
)
type Transfer struct {
serialNumber int64
upstreamConn *net.TCPConn
recorder trace.Recorder
transferChan chan *protocol.MajoraPacket
transferFunc func(t *Transfer, p *protocol.MajoraPacket)
once sync.Once
cancel chan struct{}
serialNumber int64
client *Client
upstreamConn *net.TCPConn
recorder trace.Recorder
transferChan chan *protocol.MajoraPacket
transferToUpstreamFunc func(t *Transfer, p *protocol.MajoraPacket)
transferToDownstreamFunc func(t *Transfer, data []byte, err error)
once sync.Once
cancel chan struct{}
}
func NewTransfer(serialNumber int64, conn *net.TCPConn, recorder trace.Recorder) *Transfer {
func NewTransfer(serialNumber int64, client *Client, conn *net.TCPConn, recorder trace.Recorder) *Transfer {
return &Transfer{
serialNumber: serialNumber,
client: client,
upstreamConn: conn,
recorder: recorder,
transferChan: make(chan *protocol.MajoraPacket, 10),
......@@ -32,28 +37,51 @@ func NewTransfer(serialNumber int64, conn *net.TCPConn, recorder trace.Recorder)
}
}
func (t *Transfer) SetTransferFunc(f func(t *Transfer, p *protocol.MajoraPacket)) {
t.transferFunc = f
func (t *Transfer) SetTransferToUpstreamFunc(f func(t *Transfer, p *protocol.MajoraPacket)) {
t.transferToUpstreamFunc = f
}
func (t *Transfer) Send(p *protocol.MajoraPacket) {
func (t *Transfer) SetTransferToDownstreamFunc(f func(t *Transfer, data []byte, err error)) {
t.transferToDownstreamFunc = f
}
func (t *Transfer) TransferToUpstream(p *protocol.MajoraPacket) {
t.transferChan <- p
}
func (t *Transfer) Start() {
if t.transferFunc == nil {
panic(errors.New("transferFunc is nil"))
if t.transferToUpstreamFunc == nil {
panic(errors.New("transferToUpstreamFunc is nil"))
}
if t.transferToDownstreamFunc == nil {
panic(errors.New("transferToDownstreamFunc is nil"))
}
safe.Go(func() {
for {
select {
case p := <-t.transferChan:
t.transferFunc(t, p)
t.transferToUpstreamFunc(t, p)
case <-t.cancel:
return
}
}
})
safe.Go(func() {
traceRecorder := t.recorder
traceRecorder.RecordEvent(trace.UpStreamEvent, fmt.Sprintf("Ready read from upstream (sn:%d)", t.serialNumber))
log.Run().Debugf("[handleUpStream] %d-> handleUpStream start...", t.serialNumber)
for {
buf := make([]byte, common.BufSize)
cnt, err := t.upstreamConn.Read(buf)
t.transferToDownstreamFunc(t, buf[0:cnt], err)
if err != nil {
break
}
}
})
}
func (t *Transfer) Close() {
......
tunnel_addr: majora-vps-zj.virjar.com:5879
tunnel_addr: 127.0.0.1:5879
dns_server: 114.114.114.114:53
daemon: true
#daemon: true
log_level: debug
log_path: ./majora-log/
reconn_intervalz: 5s
......@@ -8,11 +8,11 @@ net_check_interval: 5s
dns_cache_duration: 10m
net_check_url: https://www.baidu.com[extra]
redial:
command: /bin/bash
exec_path: /root/ppp_redial.sh
redial_duration: 5m
wait_time: 10s
#redial:
# command: /bin/bash
# exec_path: /root/ppp_redial.sh
# redial_duration: 5m
# wait_time: 10s
extra:
account: superman
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment