diff options
Diffstat (limited to 'weed/udptransfer/endpoint.go')
| -rw-r--r-- | weed/udptransfer/endpoint.go | 339 |
1 files changed, 339 insertions, 0 deletions
diff --git a/weed/udptransfer/endpoint.go b/weed/udptransfer/endpoint.go new file mode 100644 index 000000000..d19d1a4f5 --- /dev/null +++ b/weed/udptransfer/endpoint.go @@ -0,0 +1,339 @@ +package udptransfer + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "math/rand" + "net" + "sort" + "sync" + "sync/atomic" + "time" + + "github.com/cloudflare/golibs/bytepool" +) + +const ( + _SO_BUF_SIZE = 8 << 20 +) + +var ( + bpool bytepool.BytePool +) + +type Params struct { + LocalAddr string + Bandwidth int64 + Mtu int + IsServ bool + FastRetransmit bool + FlatTraffic bool + EnablePprof bool + Stacktrace bool + Debug int +} + +type connID struct { + lid uint32 + rid uint32 +} + +type Endpoint struct { + udpconn *net.UDPConn + state int32 + idSeq uint32 + isServ bool + listenChan chan *Conn + lRegistry map[uint32]*Conn + rRegistry map[string][]uint32 + mlock sync.RWMutex + timeout *time.Timer + params Params +} + +func (c *connID) setRid(b []byte) { + c.rid = binary.BigEndian.Uint32(b[_MAGIC_SIZE+6:]) +} + +func init() { + bpool.Init(0, 2000) + rand.Seed(NowNS()) +} + +func NewEndpoint(p *Params) (*Endpoint, error) { + set_debug_params(p) + if p.Bandwidth <= 0 || p.Bandwidth > 100 { + return nil, fmt.Errorf("bw->(0,100]") + } + conn, err := net.ListenPacket("udp", p.LocalAddr) + if err != nil { + return nil, err + } + e := &Endpoint{ + udpconn: conn.(*net.UDPConn), + idSeq: 1, + isServ: p.IsServ, + listenChan: make(chan *Conn, 1), + lRegistry: make(map[uint32]*Conn), + rRegistry: make(map[string][]uint32), + timeout: time.NewTimer(0), + params: *p, + } + if e.isServ { + e.state = _S_EST0 + } else { // client + e.state = _S_EST1 + e.idSeq = uint32(rand.Int31()) + } + e.params.Bandwidth = p.Bandwidth << 20 // mbps to bps + e.udpconn.SetReadBuffer(_SO_BUF_SIZE) + go e.internal_listen() + return e, nil +} + +func (e *Endpoint) internal_listen() { + const rtmo = time.Duration(30*time.Second) + var id connID + for { + //var buf = make([]byte, 1600) + var buf = bpool.Get(1600) + e.udpconn.SetReadDeadline(time.Now().Add(rtmo)) + n, addr, err := e.udpconn.ReadFromUDP(buf) + if err == nil && n >= _AH_SIZE { + buf = buf[:n] + e.getConnID(&id, buf) + + switch id.lid { + case 0: // new connection + if e.isServ { + go e.acceptNewConn(id, addr, buf) + } else { + dumpb("drop", buf) + } + + case _INVALID_SEQ: + dumpb("drop invalid", buf) + + default: // old connection + e.mlock.RLock() + conn := e.lRegistry[id.lid] + e.mlock.RUnlock() + if conn != nil { + e.dispatch(conn, buf) + } else { + e.resetPeer(addr, id) + dumpb("drop null", buf) + } + } + + } else if err != nil { + // idle process + if nerr, y := err.(net.Error); y && nerr.Timeout() { + e.idleProcess() + continue + } + // other errors + if atomic.LoadInt32(&e.state) == _S_FIN { + return + } else { + log.Println("Error: read sock", err) + } + } + } +} + +func (e *Endpoint) idleProcess() { + // recycle/shrink memory + bpool.Drain() + e.mlock.Lock() + defer e.mlock.Unlock() + // reset urgent + for _, c := range e.lRegistry { + c.outlock.Lock() + if c.outQ.size() == 0 && c.urgent != 0 { + c.urgent = 0 + } + c.outlock.Unlock() + } +} + +func (e *Endpoint) Dial(addr string) (*Conn, error) { + rAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + e.mlock.Lock() + e.idSeq++ + id := connID{e.idSeq, 0} + conn := NewConn(e, rAddr, id) + e.lRegistry[id.lid] = conn + e.mlock.Unlock() + if atomic.LoadInt32(&e.state) != _S_FIN { + err = conn.initConnection(nil) + return conn, err + } + return nil, io.EOF +} + +func (e *Endpoint) acceptNewConn(id connID, addr *net.UDPAddr, buf []byte) { + rKey := addr.String() + e.mlock.Lock() + // map: remoteAddr => remoteConnID + // filter duplicated syn packets + if newArr := insertRid(e.rRegistry[rKey], id.rid); newArr != nil { + e.rRegistry[rKey] = newArr + } else { + e.mlock.Unlock() + log.Println("Warn: duplicated connection", addr) + return + } + e.idSeq++ + id.lid = e.idSeq + conn := NewConn(e, addr, id) + e.lRegistry[id.lid] = conn + e.mlock.Unlock() + err := conn.initConnection(buf) + if err == nil { + select { + case e.listenChan <- conn: + case <-time.After(_10ms): + log.Println("Warn: no listener") + } + } else { + e.removeConn(id, addr) + log.Println("Error: init_connection", addr, err) + } +} + +func (e *Endpoint) removeConn(id connID, addr *net.UDPAddr) { + e.mlock.Lock() + delete(e.lRegistry, id.lid) + rKey := addr.String() + if newArr := deleteRid(e.rRegistry[rKey], id.rid); newArr != nil { + if len(newArr) > 0 { + e.rRegistry[rKey] = newArr + } else { + delete(e.rRegistry, rKey) + } + } + e.mlock.Unlock() +} + +// net.Listener +func (e *Endpoint) Close() error { + state := atomic.LoadInt32(&e.state) + if state > 0 && atomic.CompareAndSwapInt32(&e.state, state, _S_FIN) { + err := e.udpconn.Close() + e.lRegistry = nil + e.rRegistry = nil + select { // release listeners + case e.listenChan <- nil: + default: + } + return err + } + return nil +} + +// net.Listener +func (e *Endpoint) Addr() net.Addr { + return e.udpconn.LocalAddr() +} + +// net.Listener +func (e *Endpoint) Accept() (net.Conn, error) { + if atomic.LoadInt32(&e.state) == _S_EST0 { + return <-e.listenChan, nil + } else { + return nil, io.EOF + } +} + +func (e *Endpoint) Listen() *Conn { + if atomic.LoadInt32(&e.state) == _S_EST0 { + return <-e.listenChan + } else { + return nil + } +} + +// tmo in MS +func (e *Endpoint) ListenTimeout(tmo int64) *Conn { + if tmo <= 0 { + return e.Listen() + } + if atomic.LoadInt32(&e.state) == _S_EST0 { + select { + case c := <-e.listenChan: + return c + case <-NewTimerChan(tmo): + } + } + return nil +} + +func (e *Endpoint) getConnID(idPtr *connID, buf []byte) { + // TODO determine magic header + magicAndLen := binary.BigEndian.Uint64(buf) + if int(magicAndLen&0xFFff) == len(buf) { + id := binary.BigEndian.Uint64(buf[_MAGIC_SIZE+2:]) + idPtr.lid = uint32(id >> 32) + idPtr.rid = uint32(id) + } else { + idPtr.lid = _INVALID_SEQ + } +} + +func (e *Endpoint) dispatch(c *Conn, buf []byte) { + e.timeout.Reset(30*time.Millisecond) + select { + case c.evRecv <- buf: + case <-e.timeout.C: + log.Println("Warn: dispatch packet failed") + } +} + +func (e *Endpoint) resetPeer(addr *net.UDPAddr, id connID) { + pk := &packet{flag: _F_FIN | _F_RESET} + buf := nodeOf(pk).marshall(id) + e.udpconn.WriteToUDP(buf, addr) +} + +type u32Slice []uint32 + +func (p u32Slice) Len() int { return len(p) } +func (p u32Slice) Less(i, j int) bool { return p[i] < p[j] } +func (p u32Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +// if the rid is not existed in array then insert it return new array +func insertRid(array []uint32, rid uint32) []uint32 { + if len(array) > 0 { + pos := sort.Search(len(array), func(n int) bool { + return array[n] >= rid + }) + if pos < len(array) && array[pos] == rid { + return nil + } + } + array = append(array, rid) + sort.Sort(u32Slice(array)) + return array +} + +// if rid was existed in array then delete it return new array +func deleteRid(array []uint32, rid uint32) []uint32 { + if len(array) > 0 { + pos := sort.Search(len(array), func(n int) bool { + return array[n] >= rid + }) + if pos < len(array) && array[pos] == rid { + newArray := make([]uint32, len(array)-1) + n := copy(newArray, array[:pos]) + copy(newArray[n:], array[pos+1:]) + return newArray + } + } + return nil +} |
