Diffstat (limited to 'weed/udptransfer/conn.go')
 weed/udptransfer/conn.go | 715 +++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 715 insertions(+), 0 deletions(-)
diff --git a/weed/udptransfer/conn.go b/weed/udptransfer/conn.go
new file mode 100644
index 000000000..e2eca49da
--- /dev/null
+++ b/weed/udptransfer/conn.go
@@ -0,0 +1,715 @@
package udptransfer

import (
	"encoding/binary"
	"fmt"
	"io"
	"log"
	"net"
	"time"
)

const (
	_MAX_RETRIES = 6
	_MIN_RTT     = 8
	_MIN_RTO     = 30
	_MIN_ATO     = 2
	_MAX_ATO     = 10
	_MIN_SWND    = 10
	_MAX_SWND    = 960
)

const (
	_VACK_SCHED = iota + 1
	_VACK_QUICK
	_VACK_MUST
	_VSWND_ACTIVE
	_VRETR_IMMED
)

const (
	_RETR_REST = -1
	_CLOSE     = 0xff
)

var debug int

func nodeOf(pk *packet) *qNode {
	return &qNode{packet: pk}
}

// internalRecvLoop dispatches inbound packets to the ack and data handlers.
func (c *Conn) internalRecvLoop() {
	defer func() {
		// avoid sending to a closed channel when replayed data
		// packets arrive during shutdown
		_ = recover()
	}()
	var buf, body []byte
	for {
		buf = <-c.evRecv
		if buf != nil {
			body = buf[_TH_SIZE:]
		} else { // shutdown
			return
		}
		pk := new(packet)
		// keep the original buffer so it can be recycled later
		pk.buffer = buf
		unmarshall(pk, body)
		if pk.flag&_F_SACK != 0 {
			c.processSAck(pk)
			continue
		}
		if pk.flag&_F_ACK != 0 {
			c.processAck(pk)
		}
		if pk.flag&_F_DATA != 0 {
			c.insertData(pk)
		} else if pk.flag&_F_FIN != 0 {
			if pk.flag&_F_RESET != 0 {
				go c.forceShutdownWithLock()
			} else {
				go c.closeR(pk)
			}
		}
	}
}

// internalSendLoop drives the retransmission timer.
func (c *Conn) internalSendLoop() {
	var timer = time.NewTimer(time.Duration(c.rtt) * time.Millisecond)
	for {
		select {
		case v := <-c.evSWnd:
			switch v {
			case _VRETR_IMMED:
				c.outlock.Lock()
				c.retransmit2()
				c.outlock.Unlock()
			case _VSWND_ACTIVE:
				timer.Reset(time.Duration(c.rtt) * time.Millisecond)
			case _CLOSE:
				return
			}
		case <-timer.C: // timer expired
			var notifySender bool
			c.outlock.Lock()
			rest, _ := c.retransmit()
			switch rest {
			case _RETR_REST, 0: // nothing to send
				if c.outQ.size() > 0 {
					timer.Reset(time.Duration(c.rtt) * time.Millisecond)
				} else {
					timer.Stop()
					// avoid blocking the sender
					notifySender = true
				}
			default: // nearest rto deadline
				timer.Reset(time.Duration(minI64(rest, c.rtt)) * time.Millisecond)
			}
			c.outlock.Unlock()
			if notifySender {
				select {
				case c.evSend <- 1:
				default:
				}
			}
		}
	}
}

// internalAckLoop coalesces ack requests and sends ACK/SACK replies.
func (c *Conn) internalAckLoop() {
	// var ackTimer = time.NewTicker(time.Duration(c.ato))
	var ackTimer = time.NewTimer(time.Duration(c.ato) * time.Millisecond)
	var lastAckState byte
	for {
		var v byte
		select {
		case <-ackTimer.C:
			// may send a duplicated ack if ato > rtt
			v = _VACK_QUICK
		case v = <-c.evAck:
			ackTimer.Reset(time.Duration(c.ato) * time.Millisecond)
			state := lastAckState
			lastAckState = v
			if state != v {
				if v == _CLOSE {
					return
				}
				v = _VACK_MUST
			}
		}
		c.inlock.Lock()
		if pkAck := c.makeAck(v); pkAck != nil {
			c.internalWrite(nodeOf(pkAck))
		}
		c.inlock.Unlock()
	}
}

// retransmit resends every pending packet older than one rto and returns the
// time until the nearest remaining deadline, plus the number of resends.
func (c *Conn) retransmit() (rest int64, count int32) {
	var now, rto = Now(), c.rto
	var limit = c.cwnd
	for item := c.outQ.head; item != nil && limit > 0; item = item.next {
		if item.scnt != _SENT_OK { // ACKed has scnt==-1
			diff := now - item.sent
			if diff > rto { // already past rto
				c.internalWrite(item)
				count++
			} else {
				// keep searching for the smallest remaining rto
				if rest > 0 {
					rest = minI64(rest, rto-diff+1)
				} else {
					rest = rto - diff + 1
				}
				limit--
			}
		}
	}
	c.outDupCnt += int(count)
	if count > 0 {
		shrcond := (c.fastRetransmit && count > maxI32(c.cwnd>>5, 4)) || (!c.fastRetransmit && count > c.cwnd>>3)
		if shrcond && now-c.lastShrink > c.rto {
			log.Printf("shrink cwnd from=%d to=%d s/4=%d", c.cwnd, c.cwnd>>1, c.swnd>>2)
			c.lastShrink = now
			// shrink cwnd, but keep cwnd >= swnd/4
			if c.cwnd > c.swnd>>1 {
				c.cwnd >>= 1
			}
		}
	}
	if c.outQ.size() > 0 {
		return
	}
	return _RETR_REST, 0
}
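The retransmit pass above is the heart of the timer loop: it resends every pending packet older than one rto and reports the nearest future deadline so internalSendLoop can re-arm its timer. A minimal standalone sketch of that scan, with hypothetical names (scanRTO is an editor's illustration, not part of this package):

package main

import "fmt"

// scanRTO returns the indexes to resend now and the ms until the next
// deadline (-1 when nothing remains pending), given per-packet sent times.
func scanRTO(sent []int64, now, rto int64) (resend []int, rest int64) {
	rest = -1
	for i, s := range sent {
		if diff := now - s; diff > rto {
			resend = append(resend, i) // already past its rto: retransmit
		} else if rest < 0 || rto-diff+1 < rest {
			rest = rto - diff + 1 // nearest future deadline, as in retransmit()
		}
	}
	return
}

func main() {
	resend, rest := scanRTO([]int64{100, 140, 190}, 200, 50)
	fmt.Println(resend, rest) // [0 1] 41 -> packet 2 expires 41ms from now
}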
// retransmit2 fast-retransmits packets that have been reported missing at
// least three times, without waiting for the full rto.
func (c *Conn) retransmit2() (count int32) {
	var limit, now = minI32(c.outPending>>4, 8), Now()
	var fRtt = c.rtt
	if now-c.lastShrink > c.rto {
		fRtt += maxI64(c.rtt>>4, 1)
	} else {
		fRtt += maxI64(c.rtt>>1, 2)
	}
	for item := c.outQ.head; item != nil && count < limit; item = item.next {
		if item.scnt != _SENT_OK { // ACKed has scnt==-1
			if item.miss >= 3 && now-item.sent >= fRtt {
				item.miss = 0
				c.internalWrite(item)
				count++
			}
		}
	}
	c.fRCnt += int(count)
	c.outDupCnt += int(count)
	return
}

// inputAndSend assigns a sequence number to pk, enqueues and transmits it,
// blocking while the inflight window is full and pacing if flatTraffic is set.
func (c *Conn) inputAndSend(pk *packet) error {
	item := &qNode{packet: pk}
	if c.mySeq&3 == 1 {
		c.tSlotT0 = NowNS()
	}
	c.outlock.Lock()
	// block while inflight packets exceed cwnd;
	// inflight includes: 1. unacked, 2. missed
	for c.outPending >= c.cwnd+c.missed {
		c.outlock.Unlock()
		if c.wtmo > 0 {
			var tmo int64
			tmo, c.wtmo = c.wtmo, 0
			select {
			case v := <-c.evSend:
				if v == _CLOSE {
					return io.EOF
				}
			case <-NewTimerChan(tmo):
				return ErrIOTimeout
			}
		} else {
			if v := <-c.evSend; v == _CLOSE {
				return io.EOF
			}
		}
		c.outlock.Lock()
	}
	c.outPending++
	c.outPkCnt++
	c.mySeq++
	pk.seq = c.mySeq
	c.outQ.appendTail(item)
	c.internalWrite(item)
	c.outlock.Unlock()
	// activate the resend timer; this send must block
	c.evSWnd <- _VSWND_ACTIVE
	if c.mySeq&3 == 0 && c.flatTraffic {
		// calculate the time error between the time-slot budget and the
		// actual usage, taking the previous sleep error into account
		t1 := NowNS()
		terr := c.tSlot<<2 - c.lastSErr - (t1 - c.tSlotT0)
		// sleep terr/2 if the batch finished more than 100us early
		if terr > 1e5 { // 100us
			time.Sleep(time.Duration(terr >> 1))
			c.lastSErr = maxI64(NowNS()-t1-terr, 0)
		} else {
			c.lastSErr >>= 1
		}
	}
	return nil
}

// internalWrite sends or resends one packet, handling retry exhaustion.
func (c *Conn) internalWrite(item *qNode) {
	if item.scnt >= 20 {
		// a FIN is never retried endlessly; just fake the shutdown
		if item.flag&_F_FIN != 0 {
			c.fakeShutdown()
			c.dest = nil
			return
		} else {
			log.Println("Warn: too many retries", item)
			if c.urgent > 0 { // abort
				c.forceShutdown()
				return
			} else { // back off and allow another round of retries
				c.urgent++
				item.scnt = 10
			}
		}
	}
	// update the current sent time and the previous sent time
	item.sent, item.sent_1 = Now(), item.sent
	item.scnt++
	buf := item.marshall(c.connID)
	if debug >= 3 {
		var pkType = packetTypeNames[item.flag]
		if item.flag&_F_SACK != 0 {
			log.Printf("send %s trp=%d on=%d %x", pkType, item.seq, item.ack, buf[_AH_SIZE+4:])
		} else {
			log.Printf("send %s seq=%d ack=%d scnt=%d len=%d", pkType, item.seq, item.ack, item.scnt, len(buf)-_TH_SIZE)
		}
	}
	c.sock.WriteToUDP(buf, c.dest)
}

func (c *Conn) logAck(ack uint32) {
	c.lastAck = ack
	c.lastAckTime = Now()
}

func (c *Conn) makeLastAck() (pk *packet) {
	c.inlock.Lock()
	defer c.inlock.Unlock()
	if Now()-c.lastAckTime < c.rtt {
		return nil
	}
	pk = &packet{
		ack:  maxU32(c.lastAck, c.inQ.maxCtnSeq),
		flag: _F_ACK,
	}
	c.logAck(pk.ack)
	return
}
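inputAndSend paces outbound traffic when flatTraffic is set: every fourth packet it compares the four-slot time budget (tSlot<<2) against the time actually spent and sleeps away half of the surplus, carrying the residual sleep error into the next batch. A self-contained sketch of that arithmetic under hypothetical names (pace is an editor's illustration, not package code):

package main

import (
	"fmt"
	"time"
)

// pace sleeps so that a 4-packet batch costs roughly 4*slot nanoseconds.
// t0 is when the batch started; lastErr carries the previous batch's sleep
// error forward, playing the role of c.lastSErr.
func pace(slot, t0, lastErr int64) (newErr int64) {
	t1 := time.Now().UnixNano()
	terr := slot<<2 - lastErr - (t1 - t0) // remaining budget for this batch
	if terr > 1e5 {                       // only bother sleeping above 100us
		time.Sleep(time.Duration(terr >> 1))
		over := time.Now().UnixNano() - t1 - terr
		if over < 0 {
			over = 0 // clamp, as maxI64(..., 0) does in inputAndSend
		}
		return over
	}
	return lastErr >> 1
}

func main() {
	t0 := time.Now().UnixNano()
	// pretend a 4-packet batch with a 1ms slot finished instantly
	fmt.Println("carried error(ns):", pace(1e6, t0, 0))
}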
// makeAck builds an ACK, or a SACK carrying the holes bitmap and a delay
// timestamp for rtt measurement.
func (c *Conn) makeAck(level byte) (pk *packet) {
	now := Now()
	if level < _VACK_MUST && now-c.lastAckTime < c.ato {
		if level < _VACK_QUICK || now-c.lastAckTime < minI64(c.ato>>2, 1) {
			return
		}
	}
	//   ready Q  <-|
	//              |-> outQ start (or further right)
	//              |-> bitmap start
	// [predecessor] [predecessor+1] [predecessor+2] .....
	var fakeSAck bool
	var predecessor = c.inQ.maxCtnSeq
	bmap, tbl := c.inQ.makeHolesBitmap(predecessor)
	if len(bmap) <= 0 { // fake sack
		bmap = make([]uint64, 1)
		bmap[0], tbl = 1, 1
		fakeSAck = true
	}
	// 4-byte head: TBL:1 | SCNT:1 | DELAY:2
	buf := make([]byte, len(bmap)*8+4)
	pk = &packet{
		ack:     predecessor + 1,
		flag:    _F_SACK,
		payload: buf,
	}
	if fakeSAck {
		pk.ack--
	}
	buf[0] = byte(tbl)
	// record the delay relative to the time reference point
	if trp := c.inQ.lastIns; trp != nil {
		delayed := now - trp.sent
		if delayed < c.rtt {
			pk.seq = trp.seq
			pk.flag |= _F_TIME
			buf[1] = trp.scnt
			if delayed <= 0 {
				delayed = 1
			}
			binary.BigEndian.PutUint16(buf[2:], uint16(delayed))
		}
	}
	buf1 := buf[4:]
	for i, b := range bmap {
		binary.BigEndian.PutUint64(buf1[i*8:], b)
	}
	c.logAck(predecessor)
	return
}

func unmarshallSAck(data []byte) (bmap []uint64, tbl uint32, delayed uint16, scnt uint8) {
	if len(data) > 0 {
		bmap = make([]uint64, len(data)>>3)
	} else {
		return
	}
	tbl = uint32(data[0])
	scnt = data[1]
	delayed = binary.BigEndian.Uint16(data[2:])
	data = data[4:]
	for i := 0; i < len(bmap); i++ {
		bmap[i] = binary.BigEndian.Uint64(data[i*8:])
	}
	return
}

// calSwnd derives the send window from the bandwidth-delay product:
// bandwidth (bit/s) * rtt (ms) / 8000 yields bytes in flight, and dividing
// by the MSS yields packets, clamped to [_MIN_SWND, _MAX_SWND].
func calSwnd(bandwidth, rtt int64) int32 {
	w := int32(bandwidth * rtt / (8000 * _MSS))
	if w <= _MAX_SWND {
		if w >= _MIN_SWND {
			return w
		}
		return _MIN_SWND
	}
	return _MAX_SWND
}
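makeAck frames a SACK as a 4-byte head (TBL:1 | SCNT:1 | DELAY:2, big-endian) followed by 64-bit bitmap words, and unmarshallSAck reverses it. A round-trip sketch of the same wire format as standalone illustrative code (encodeSAck and decodeSAck are editor's names, not package functions):

package main

import (
	"encoding/binary"
	"fmt"
)

func encodeSAck(bmap []uint64, tbl uint32, scnt uint8, delayed uint16) []byte {
	buf := make([]byte, len(bmap)*8+4)
	buf[0] = byte(tbl)
	buf[1] = scnt
	binary.BigEndian.PutUint16(buf[2:], delayed)
	for i, b := range bmap {
		binary.BigEndian.PutUint64(buf[4+i*8:], b)
	}
	return buf
}

func decodeSAck(data []byte) (bmap []uint64, tbl uint32, delayed uint16, scnt uint8) {
	bmap = make([]uint64, len(data)>>3) // (4+8n)>>3 == n, same as unmarshallSAck
	tbl = uint32(data[0])
	scnt = data[1]
	delayed = binary.BigEndian.Uint16(data[2:])
	for i := range bmap {
		bmap[i] = binary.BigEndian.Uint64(data[4+i*8:])
	}
	return
}

func main() {
	frame := encodeSAck([]uint64{0b1011}, 1, 2, 7)
	bmap, tbl, delayed, scnt := decodeSAck(frame)
	fmt.Printf("bmap=%b tbl=%d delayed=%d scnt=%d\n", bmap, tbl, delayed, scnt)
}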
// measure updates the rtt, rto, ato and send-window estimates from a timed SACK.
func (c *Conn) measure(seq uint32, delayed int64, scnt uint8) {
	target := c.outQ.get(seq)
	if target == nil {
		return
	}
	var lastSent int64
	switch target.scnt - scnt {
	case 0:
		// not sent again since this ack was sent out
		lastSent = target.sent
	case 1:
		// sent once more since this ack was sent out,
		// so use the previous sent time
		lastSent = target.sent_1
	default:
		// can't measure: the packet was sent too many times
		return
	}
	// real-time rtt
	rtt := Now() - lastSent - delayed
	// reject abnormal measurements:
	// 1. rtt too small: below rtt/8
	// 2. backlogged too long: delayed beyond rtt/2
	if rtt < maxI64(c.rtt>>3, 1) || delayed > c.rtt>>1 {
		return
	}
	// srtt: updated with gain 1/8
	err := rtt - (c.srtt >> 3)
	c.srtt += err
	c.rtt = c.srtt >> 3
	if c.rtt < _MIN_RTT {
		c.rtt = _MIN_RTT
	}
	// smoothed swnd: updated with gain 1/8
	swnd := c.swnd<<3 - c.swnd + calSwnd(c.bandwidth, c.rtt)
	c.swnd = swnd >> 3
	c.tSlot = c.rtt * 1e6 / int64(c.swnd)
	c.ato = c.rtt >> 4
	if c.ato < _MIN_ATO {
		c.ato = _MIN_ATO
	} else if c.ato > _MAX_ATO {
		c.ato = _MAX_ATO
	}
	if err < 0 {
		err = -err
		err -= c.mdev >> 2
		if err > 0 {
			err >>= 3
		}
	} else {
		err -= c.mdev >> 2
	}
	// mdev: updated with gain 1/4
	c.mdev += err
	rto := c.rtt + maxI64(c.rtt<<1, c.mdev)
	if rto >= c.rto {
		c.rto = rto
	} else {
		c.rto = (c.rto + rto) >> 1
	}
	if c.rto < _MIN_RTO {
		c.rto = _MIN_RTO
	}
	if debug >= 1 {
		log.Printf("--- rtt=%d srtt=%d rto=%d swnd=%d", c.rtt, c.srtt, c.rto, c.swnd)
	}
}

func (c *Conn) processSAck(pk *packet) {
	c.outlock.Lock()
	bmap, tbl, delayed, scnt := unmarshallSAck(pk.payload)
	if bmap == nil { // bad packet
		c.outlock.Unlock()
		return
	}
	if pk.flag&_F_TIME != 0 {
		c.measure(pk.seq, int64(delayed), scnt)
	}
	deleted, missed, continuous := c.outQ.deleteByBitmap(bmap, pk.ack, tbl)
	if deleted > 0 {
		c.ackHit(deleted, missed)
		// the lock has been released by ackHit
	} else {
		c.outlock.Unlock()
	}
	if c.fastRetransmit && !continuous {
		// the peer queue is discontinuous, so trigger fast retransmission
		if deleted == 0 {
			c.evSWnd <- _VRETR_IMMED
		} else {
			select {
			case c.evSWnd <- _VRETR_IMMED:
			default:
			}
		}
	}
	if debug >= 2 {
		log.Printf("SACK qhead=%d deleted=%d outPending=%d on=%d %016x",
			c.outQ.distanceOfHead(0), deleted, c.outPending, pk.ack, bmap)
	}
}

func (c *Conn) processAck(pk *packet) {
	c.outlock.Lock()
	if end := c.outQ.get(pk.ack); end != nil { // ack hit
		_, deleted := c.outQ.deleteBefore(end)
		c.ackHit(deleted, 0) // the lock is released by ackHit
		if debug >= 2 {
			log.Printf("ACK hit on=%d", pk.ack)
		}
		// special case: the ack of the FIN
		if pk.seq == _FIN_ACK_SEQ {
			select {
			case c.evClose <- _S_FIN0:
			default:
			}
		}
	} else { // duplicated ack
		if debug >= 2 {
			log.Printf("ACK miss on=%d", pk.ack)
		}
		if pk.flag&_F_SYN != 0 { // the No.3 ack of the handshake was lost
			if pkAck := c.makeLastAck(); pkAck != nil {
				c.internalWrite(nodeOf(pkAck))
			}
		}
		c.outlock.Unlock()
	}
}

// ackHit grows the congestion window and refreshes the missed counter after
// acknowledgments. It must be entered holding outlock, which it releases.
func (c *Conn) ackHit(deleted, missed int32) {
	c.outPending -= deleted
	now := Now()
	if c.cwnd < c.swnd && now-c.lastShrink > c.rto {
		if c.cwnd < c.swnd>>1 {
			c.cwnd <<= 1
		} else {
			c.cwnd += deleted << 1
		}
	}
	if c.cwnd > c.swnd {
		c.cwnd = c.swnd
	}
	if now-c.lastRstMis > c.ato {
		c.lastRstMis = now
		c.missed = missed
	} else {
		c.missed = c.missed>>1 + missed
	}
	if qswnd := c.swnd >> 4; c.missed > qswnd {
		c.missed = qswnd
	}
	c.outlock.Unlock()
	select {
	case c.evSend <- 1:
	default:
	}
}
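measure is a Jacobson/Karels-style estimator: srtt is kept as eight times the rtt and updated with gain 1/8, mdev with gain roughly 1/4, and rto = rtt + max(2*rtt, mdev) with floors applied. The standalone sketch below mirrors that smoothing while omitting the sample-rejection guards; all names are hypothetical (an editor's illustration, not package code):

package main

import "fmt"

const minRTT, minRTO = 8, 30 // the same floors as _MIN_RTT/_MIN_RTO

type estimator struct{ rtt, srtt, mdev, rto int64 }

func (e *estimator) update(sample int64) {
	err := sample - e.srtt>>3
	e.srtt += err // srtt holds 8*rtt, so this is a 1/8 gain
	if e.rtt = e.srtt >> 3; e.rtt < minRTT {
		e.rtt = minRTT
	}
	if err < 0 {
		err = -err
		if err -= e.mdev >> 2; err > 0 {
			err >>= 3 // damp downward deviation, as measure() does
		}
	} else {
		err -= e.mdev >> 2
	}
	e.mdev += err
	rto := e.rtt + max64(e.rtt<<1, e.mdev)
	if rto >= e.rto {
		e.rto = rto
	} else {
		e.rto = (e.rto + rto) >> 1 // let rto decay gradually
	}
	if e.rto < minRTO {
		e.rto = minRTO
	}
}

func max64(a, b int64) int64 {
	if a > b {
		return a
	}
	return b
}

func main() {
	e := &estimator{}
	for _, s := range []int64{40, 42, 38, 90, 41} {
		e.update(s)
		fmt.Printf("sample=%d rtt=%d rto=%d\n", s, e.rtt, e.rto)
	}
}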
// insertData inserts a received data packet into inQ, schedules the proper
// ack level, and wakes the reader when new contiguous data is available.
func (c *Conn) insertData(pk *packet) {
	c.inlock.Lock()
	defer c.inlock.Unlock()
	exists := c.inQ.contains(pk.seq)
	// a duplicate of a queued or already-read packet
	// means the last ACK was lost
	if exists || pk.seq <= c.inQ.maxCtnSeq {
		// so send an ACK for the duplicate
		select {
		case c.evAck <- _VACK_MUST:
		default:
		}
		if debug >= 2 {
			dumpQ(fmt.Sprint("duplicated ", pk.seq), c.inQ)
		}
		c.inDupCnt++
		return
	}
	// record the current time in .sent and treat it as the received time
	item := &qNode{packet: pk, sent: Now()}
	dis := c.inQ.searchInsert(item, c.lastReadSeq)
	if debug >= 3 {
		log.Printf("\t\t\trecv DATA seq=%d dis=%d maxCtn=%d lastReadSeq=%d", item.seq, dis, c.inQ.maxCtnSeq, c.lastReadSeq)
	}

	var ackState byte = _VACK_MUST
	var available bool
	switch dis {
	case 0: // impossible; duplicates were filtered above
		return
	case 1:
		if c.inQDirty {
			available = c.inQ.updateContinuous(item)
			if c.inQ.isWholeContinuous() { // the whole queue is ordered
				c.inQDirty = false
			} else { // holes still exist
				ackState = _VACK_QUICK
			}
		} else {
			// the ideal case: the packet arrived exactly in order
			c.inQ.maxCtnSeq = pk.seq
			available = true
			ackState = _VACK_SCHED
		}

	default: // an out-of-order packet; a hole has appeared here
		if !c.inQDirty {
			c.inQDirty = true
		}
	}

	// count the valid received packet
	c.inPkCnt++
	c.inQ.lastIns = item
	// try to notify the acker
	select {
	case c.evAck <- ackState:
	default:
	}
	if available { // try to notify the reader
		select {
		case c.evRead <- 1:
		default:
		}
	}
}

// readInQ moves the contiguous prefix of inQ into inQReady, recycling the
// packet buffers. It reports whether any new data became readable.
func (c *Conn) readInQ() bool {
	c.inlock.Lock()
	defer c.inlock.Unlock()
	//  read already <-|-> expected Q
	//   [lastReadSeq] | [lastReadSeq+1] [lastReadSeq+2] ......
	if c.inQ.isEqualsHead(c.lastReadSeq+1) && c.lastReadSeq < c.inQ.maxCtnSeq {
		c.lastReadSeq = c.inQ.maxCtnSeq
		available := c.inQ.get(c.inQ.maxCtnSeq)
		available, _ = c.inQ.deleteBefore(available)
		for i := available; i != nil; i = i.next {
			c.inQReady = append(c.inQReady, i.payload...)
			// the data has been copied, so the buffer can be recycled
			bpool.Put(i.buffer)
			i.payload = nil
			i.buffer = nil
		}
		return true
	}
	return false
}

// Read implements io.Reader. It must not be called concurrently.
func (c *Conn) Read(buf []byte) (nr int, err error) {
	for {
		if len(c.inQReady) > 0 {
			n := copy(buf, c.inQReady)
			c.inQReady = c.inQReady[n:]
			return n, nil
		}
		if !c.readInQ() {
			if c.rtmo > 0 {
				var tmo int64
				tmo, c.rtmo = c.rtmo, 0
				select {
				case _, y := <-c.evRead:
					if !y && len(c.inQReady) == 0 {
						return 0, io.EOF
					}
				case <-NewTimerChan(tmo):
					return 0, ErrIOTimeout
				}
			} else {
				// EOF may be returned only when evRead is closed
				// and inQReady is empty
				if _, y := <-c.evRead; !y && len(c.inQReady) == 0 {
					return 0, io.EOF
				}
			}
		}
	}
}

// Write implements io.Writer. It must not be called concurrently.
func (c *Conn) Write(data []byte) (nr int, err error) {
	for len(data) > 0 && err == nil {
		//buf := make([]byte, _MSS+_AH_SIZE)
		buf := bpool.Get(c.mss + _AH_SIZE)
		body := buf[_TH_SIZE+_CH_SIZE:]
		n := copy(body, data)
		nr += n
		data = data[n:]
		pk := &packet{flag: _F_DATA, payload: body[:n], buffer: buf[:_AH_SIZE+n]}
		err = c.inputAndSend(pk)
	}
	return
}

func (c *Conn) LocalAddr() net.Addr {
	return c.sock.LocalAddr()
}

func (c *Conn) RemoteAddr() net.Addr {
	return c.dest
}

func (c *Conn) SetDeadline(t time.Time) error {
	c.SetReadDeadline(t)
	c.SetWriteDeadline(t)
	return nil
}

func (c *Conn) SetReadDeadline(t time.Time) error {
	if d := t.UnixNano()/Millisecond - Now(); d > 0 {
		c.rtmo = d
	}
	return nil
}

func (c *Conn) SetWriteDeadline(t time.Time) error {
	if d := t.UnixNano()/Millisecond - Now(); d > 0 {
		c.wtmo = d
	}
	return nil
}
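Conn exposes the familiar net.Conn surface (Read, Write, addresses, deadlines), so callers can drive it generically. Note that SetReadDeadline stores a relative timeout that the next Read consumes once, so it has to be re-armed per read. How a *Conn is obtained is outside this file; the helper below is purely illustrative and assumes any net.Conn-compatible peer:

package main

import (
	"log"
	"net"
	"time"
)

// echoOnce writes one message and waits up to 5s for a reply. The deadline
// must be set before each Read because the timeout is consumed by one call.
func echoOnce(c net.Conn, msg []byte) ([]byte, error) {
	if _, err := c.Write(msg); err != nil {
		return nil, err
	}
	if err := c.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
		return nil, err
	}
	buf := make([]byte, 64*1024)
	n, err := c.Read(buf)
	if err != nil {
		return nil, err
	}
	return buf[:n], nil
}

func main() {
	// placeholder peer: any net.Conn works; a real caller would obtain a
	// udptransfer connection here instead.
	conn, err := net.Dial("udp", "127.0.0.1:9999")
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
	if reply, err := echoOnce(conn, []byte("ping")); err == nil {
		log.Printf("got %q", reply)
	}
}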
