Imported Upstream version 4.8.1
[platform/upstream/gcc48.git] / libgo / go / net / fd_windows.go
index 45f5c2d..ea6ef10 100644 (file)
@@ -17,19 +17,57 @@ import (
 
 var initErr error
 
-func init() {
+// CancelIo Windows API cancels all outstanding IO for a particular
+// socket on current thread. To overcome that limitation, we run
+// special goroutine, locked to OS single thread, that both starts
+// and cancels IO. It means, there are 2 unavoidable thread switches
+// for every IO.
+// Some newer versions of Windows has new CancelIoEx API, that does
+// not have that limitation and can be used from any thread. This
+// package uses CancelIoEx API, if present, otherwise it fallback
+// to CancelIo.
+
+var canCancelIO bool // determines if CancelIoEx API is present
+
+func sysInit() {
        var d syscall.WSAData
        e := syscall.WSAStartup(uint32(0x202), &d)
        if e != nil {
                initErr = os.NewSyscallError("WSAStartup", e)
        }
+       canCancelIO = syscall.LoadCancelIoEx() == nil
+       if syscall.LoadGetAddrInfo() == nil {
+               lookupIP = newLookupIP
+       }
 }
 
 func closesocket(s syscall.Handle) error {
        return syscall.Closesocket(s)
 }
 
-// Interface for all io operations.
+func canUseConnectEx(net string) bool {
+       if net == "udp" || net == "udp4" || net == "udp6" {
+               // ConnectEx windows API does not support connectionless sockets.
+               return false
+       }
+       return syscall.LoadConnectEx() == nil
+}
+
+func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
+       if !canUseConnectEx(net) {
+               // Use the relatively inefficient goroutine-racing
+               // implementation of DialTimeout.
+               return dialTimeoutRace(net, addr, timeout)
+       }
+       deadline := time.Now().Add(timeout)
+       _, addri, err := resolveNetAddr("dial", net, addr, deadline)
+       if err != nil {
+               return nil, err
+       }
+       return dialAddr(net, addr, addri, deadline)
+}
+
+// Interface for all IO operations.
 type anOpIface interface {
        Op() *anOp
        Name() string
@@ -42,7 +80,7 @@ type ioResult struct {
        err error
 }
 
-// anOp implements functionality common to all io operations.
+// anOp implements functionality common to all IO operations.
 type anOp struct {
        // Used by IOCP interface, it must be first field
        // of the struct, as our code rely on it.
@@ -75,7 +113,7 @@ func (o *anOp) Op() *anOp {
        return o
 }
 
-// bufOp is used by io operations that read / write
+// bufOp is used by IO operations that read / write
 // data from / to client buffer.
 type bufOp struct {
        anOp
@@ -92,7 +130,7 @@ func (o *bufOp) Init(fd *netFD, buf []byte, mode int) {
        }
 }
 
-// resultSrv will retrieve all io completion results from
+// resultSrv will retrieve all IO completion results from
 // iocp and send them to the correspondent waiting client
 // goroutine via channel supplied in the request.
 type resultSrv struct {
@@ -107,7 +145,7 @@ func (s *resultSrv) Run() {
                r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE)
                switch {
                case r.err == nil:
-                       // Dequeued successfully completed io packet.
+                       // Dequeued successfully completed IO packet.
                case r.err == syscall.Errno(syscall.WAIT_TIMEOUT) && o == nil:
                        // Wait has timed out (should not happen now, but might be used in the future).
                        panic("GetQueuedCompletionStatus timed out")
@@ -115,22 +153,23 @@ func (s *resultSrv) Run() {
                        // Failed to dequeue anything -> report the error.
                        panic("GetQueuedCompletionStatus failed " + r.err.Error())
                default:
-                       // Dequeued failed io packet.
+                       // Dequeued failed IO packet.
                }
                (*anOp)(unsafe.Pointer(o)).resultc <- r
        }
 }
 
-// ioSrv executes net io requests.
+// ioSrv executes net IO requests.
 type ioSrv struct {
-       submchan chan anOpIface // submit io requests
-       canchan  chan anOpIface // cancel io requests
+       submchan chan anOpIface // submit IO requests
+       canchan  chan anOpIface // cancel IO requests
 }
 
-// ProcessRemoteIO will execute submit io requests on behalf
+// ProcessRemoteIO will execute submit IO requests on behalf
 // of other goroutines, all on a single os thread, so it can
 // cancel them later. Results of all operations will be sent
 // back to their requesters via channel supplied in request.
+// It is used only when the CancelIoEx API is unavailable.
 func (s *ioSrv) ProcessRemoteIO() {
        runtime.LockOSThread()
        defer runtime.UnlockOSThread()
@@ -144,20 +183,30 @@ func (s *ioSrv) ProcessRemoteIO() {
        }
 }
 
-// ExecIO executes a single io operation. It either executes it
-// inline, or, if a deadline is employed, passes the request onto
+// ExecIO executes a single IO operation oi. It submits and cancels
+// IO in the current thread for systems where Windows CancelIoEx API
+// is available. Alternatively, it passes the request onto
 // a special goroutine and waits for completion or cancels request.
 // deadline is unix nanos.
 func (s *ioSrv) ExecIO(oi anOpIface, deadline int64) (int, error) {
        var err error
        o := oi.Op()
+       // Calculate timeout delta.
+       var delta int64
        if deadline != 0 {
+               delta = deadline - time.Now().UnixNano()
+               if delta <= 0 {
+                       return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, errTimeout}
+               }
+       }
+       // Start IO.
+       if canCancelIO {
+               err = oi.Submit()
+       } else {
                // Send request to a special dedicated thread,
-               // so it can stop the io with CancelIO later.
+               // so it can stop the IO with CancelIO later.
                s.submchan <- oi
                err = <-o.errnoc
-       } else {
-               err = oi.Submit()
        }
        switch err {
        case nil:
@@ -168,27 +217,46 @@ func (s *ioSrv) ExecIO(oi anOpIface, deadline int64) (int, error) {
        default:
                return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, err}
        }
+       // Setup timer, if deadline is given.
+       var timer <-chan time.Time
+       if delta > 0 {
+               t := time.NewTimer(time.Duration(delta) * time.Nanosecond)
+               defer t.Stop()
+               timer = t.C
+       }
        // Wait for our request to complete.
        var r ioResult
-       if deadline != 0 {
-               dt := deadline - time.Now().UnixNano()
-               if dt < 1 {
-                       dt = 1
-               }
-               timer := time.NewTimer(time.Duration(dt) * time.Nanosecond)
-               defer timer.Stop()
-               select {
-               case r = <-o.resultc:
-               case <-timer.C:
+       var cancelled, timeout bool
+       select {
+       case r = <-o.resultc:
+       case <-timer:
+               cancelled = true
+               timeout = true
+       case <-o.fd.closec:
+               cancelled = true
+       }
+       if cancelled {
+               // Cancel it.
+               if canCancelIO {
+                       err := syscall.CancelIoEx(syscall.Handle(o.Op().fd.sysfd), &o.o)
+                       // Assuming ERROR_NOT_FOUND is returned, if IO is completed.
+                       if err != nil && err != syscall.ERROR_NOT_FOUND {
+                               // TODO(brainman): maybe do something else, but panic.
+                               panic(err)
+                       }
+               } else {
                        s.canchan <- oi
                        <-o.errnoc
-                       r = <-o.resultc
-                       if r.err == syscall.ERROR_OPERATION_ABORTED { // IO Canceled
-                               r.err = syscall.EWOULDBLOCK
-                       }
                }
-       } else {
+               // Wait for IO to be canceled or complete successfully.
                r = <-o.resultc
+               if r.err == syscall.ERROR_OPERATION_ABORTED { // IO Canceled
+                       if timeout {
+                               r.err = errTimeout
+                       } else {
+                               r.err = errClosing
+                       }
+               }
        }
        if r.err != nil {
                err = &OpError{oi.Name(), o.fd.net, o.fd.laddr, r.err}
@@ -211,9 +279,13 @@ func startServer() {
        go resultsrv.Run()
 
        iosrv = new(ioSrv)
-       iosrv.submchan = make(chan anOpIface)
-       iosrv.canchan = make(chan anOpIface)
-       go iosrv.ProcessRemoteIO()
+       if !canCancelIO {
+               // Only CancelIo API is available. Lets start special goroutine
+               // locked to an OS thread, that both starts and cancels IO.
+               iosrv.submchan = make(chan anOpIface)
+               iosrv.canchan = make(chan anOpIface)
+               go iosrv.ProcessRemoteIO()
+       }
 }
 
 // Network file descriptor.
@@ -233,12 +305,13 @@ type netFD struct {
        raddr       Addr
        resultc     [2]chan ioResult // read/write completion results
        errnoc      [2]chan error    // read/write submit or cancel operation errors
+       closec      chan bool        // used by Close to cancel pending IO
+
+       // serialize access to Read and Write methods
+       rio, wio sync.Mutex
 
-       // owned by client
-       rdeadline int64
-       rio       sync.Mutex
-       wdeadline int64
-       wio       sync.Mutex
+       // read and write deadlines
+       rdeadline, wdeadline deadline
 }
 
 func allocFD(fd syscall.Handle, family, sotype int, net string) *netFD {
@@ -247,8 +320,8 @@ func allocFD(fd syscall.Handle, family, sotype int, net string) *netFD {
                family: family,
                sotype: sotype,
                net:    net,
+               closec: make(chan bool),
        }
-       runtime.SetFinalizer(netfd, (*netFD).Close)
        return netfd
 }
 
@@ -267,13 +340,52 @@ func newFD(fd syscall.Handle, family, proto int, net string) (*netFD, error) {
 func (fd *netFD) setAddr(laddr, raddr Addr) {
        fd.laddr = laddr
        fd.raddr = raddr
+       runtime.SetFinalizer(fd, (*netFD).closesocket)
 }
 
-func (fd *netFD) connect(ra syscall.Sockaddr) error {
-       return syscall.Connect(fd.sysfd, ra)
+// Make new connection.
+
+type connectOp struct {
+       anOp
+       ra syscall.Sockaddr
+}
+
+func (o *connectOp) Submit() error {
+       return syscall.ConnectEx(o.fd.sysfd, o.ra, nil, 0, nil, &o.o)
 }
 
-var errClosing = errors.New("use of closed network connection")
+func (o *connectOp) Name() string {
+       return "ConnectEx"
+}
+
+func (fd *netFD) connect(ra syscall.Sockaddr) error {
+       if !canUseConnectEx(fd.net) {
+               return syscall.Connect(fd.sysfd, ra)
+       }
+       // ConnectEx windows API requires an unconnected, previously bound socket.
+       var la syscall.Sockaddr
+       switch ra.(type) {
+       case *syscall.SockaddrInet4:
+               la = &syscall.SockaddrInet4{}
+       case *syscall.SockaddrInet6:
+               la = &syscall.SockaddrInet6{}
+       default:
+               panic("unexpected type in connect")
+       }
+       if err := syscall.Bind(fd.sysfd, la); err != nil {
+               return err
+       }
+       // Call ConnectEx API.
+       var o connectOp
+       o.Init(fd, 'w')
+       o.ra = ra
+       _, err := iosrv.ExecIO(&o, fd.wdeadline.value())
+       if err != nil {
+               return err
+       }
+       // Refresh socket properties.
+       return syscall.Setsockopt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_UPDATE_CONNECT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd)))
+}
 
 // Add a reference to this fd.
 // If closing==true, mark the fd as closing.
@@ -299,24 +411,12 @@ func (fd *netFD) incref(closing bool) error {
 // Remove a reference to this FD and close if we've been asked to do so (and
 // there are no references left.
 func (fd *netFD) decref() {
+       if fd == nil {
+               return
+       }
        fd.sysmu.Lock()
        fd.sysref--
-       // NOTE(rsc): On Unix we check fd.sysref == 0 here before closing,
-       // but on Windows we have no way to wake up the blocked I/O other
-       // than closing the socket (or calling Shutdown, which breaks other
-       // programs that might have a reference to the socket).  So there is
-       // a small race here that we might close fd.sysfd and then some other
-       // goroutine might start a read of fd.sysfd (having read it before we
-       // write InvalidHandle to it), which might refer to some other file
-       // if the specific handle value gets reused.  I think handle values on
-       // Windows are not reused as aggressively as file descriptors on Unix,
-       // so this might be tolerable.
-       if fd.closing && fd.sysfd != syscall.InvalidHandle {
-               // In case the user has set linger, switch to blocking mode so
-               // the close blocks.  As long as this doesn't happen often, we
-               // can handle the extra OS processes.  Otherwise we'll need to
-               // use the resultsrv for Close too.  Sigh.
-               syscall.SetNonblock(fd.sysfd, false)
+       if fd.closing && fd.sysref == 0 && fd.sysfd != syscall.InvalidHandle {
                closesocket(fd.sysfd)
                fd.sysfd = syscall.InvalidHandle
                // no need for a finalizer anymore
@@ -329,14 +429,22 @@ func (fd *netFD) Close() error {
        if err := fd.incref(true); err != nil {
                return err
        }
-       fd.decref()
+       defer fd.decref()
+       // unblock pending reader and writer
+       close(fd.closec)
+       // wait for both reader and writer to exit
+       fd.rio.Lock()
+       defer fd.rio.Unlock()
+       fd.wio.Lock()
+       defer fd.wio.Unlock()
        return nil
 }
 
 func (fd *netFD) shutdown(how int) error {
-       if fd == nil || fd.sysfd == syscall.InvalidHandle {
-               return syscall.EINVAL
+       if err := fd.incref(false); err != nil {
+               return err
        }
+       defer fd.decref()
        err := syscall.Shutdown(fd.sysfd, how)
        if err != nil {
                return &OpError{"shutdown", fd.net, fd.laddr, err}
@@ -352,6 +460,10 @@ func (fd *netFD) CloseWrite() error {
        return fd.shutdown(syscall.SHUT_WR)
 }
 
+func (fd *netFD) closesocket() error {
+       return closesocket(fd.sysfd)
+}
+
 // Read from network.
 
 type readOp struct {
@@ -368,21 +480,15 @@ func (o *readOp) Name() string {
 }
 
 func (fd *netFD) Read(buf []byte) (int, error) {
-       if fd == nil {
-               return 0, syscall.EINVAL
-       }
-       fd.rio.Lock()
-       defer fd.rio.Unlock()
        if err := fd.incref(false); err != nil {
                return 0, err
        }
        defer fd.decref()
-       if fd.sysfd == syscall.InvalidHandle {
-               return 0, syscall.EINVAL
-       }
+       fd.rio.Lock()
+       defer fd.rio.Unlock()
        var o readOp
        o.Init(fd, buf, 'r')
-       n, err := iosrv.ExecIO(&o, fd.rdeadline)
+       n, err := iosrv.ExecIO(&o, fd.rdeadline.value())
        if err == nil && n == 0 {
                err = io.EOF
        }
@@ -407,22 +513,19 @@ func (o *readFromOp) Name() string {
 }
 
 func (fd *netFD) ReadFrom(buf []byte) (n int, sa syscall.Sockaddr, err error) {
-       if fd == nil {
-               return 0, nil, syscall.EINVAL
-       }
        if len(buf) == 0 {
                return 0, nil, nil
        }
-       fd.rio.Lock()
-       defer fd.rio.Unlock()
        if err := fd.incref(false); err != nil {
                return 0, nil, err
        }
        defer fd.decref()
+       fd.rio.Lock()
+       defer fd.rio.Unlock()
        var o readFromOp
        o.Init(fd, buf, 'r')
        o.rsan = int32(unsafe.Sizeof(o.rsa))
-       n, err = iosrv.ExecIO(&o, fd.rdeadline)
+       n, err = iosrv.ExecIO(&o, fd.rdeadline.value())
        if err != nil {
                return 0, nil, err
        }
@@ -446,18 +549,15 @@ func (o *writeOp) Name() string {
 }
 
 func (fd *netFD) Write(buf []byte) (int, error) {
-       if fd == nil {
-               return 0, syscall.EINVAL
-       }
-       fd.wio.Lock()
-       defer fd.wio.Unlock()
        if err := fd.incref(false); err != nil {
                return 0, err
        }
        defer fd.decref()
+       fd.wio.Lock()
+       defer fd.wio.Unlock()
        var o writeOp
        o.Init(fd, buf, 'w')
-       return iosrv.ExecIO(&o, fd.wdeadline)
+       return iosrv.ExecIO(&o, fd.wdeadline.value())
 }
 
 // WriteTo to network.
@@ -477,25 +577,19 @@ func (o *writeToOp) Name() string {
 }
 
 func (fd *netFD) WriteTo(buf []byte, sa syscall.Sockaddr) (int, error) {
-       if fd == nil {
-               return 0, syscall.EINVAL
-       }
        if len(buf) == 0 {
                return 0, nil
        }
-       fd.wio.Lock()
-       defer fd.wio.Unlock()
        if err := fd.incref(false); err != nil {
                return 0, err
        }
        defer fd.decref()
-       if fd.sysfd == syscall.InvalidHandle {
-               return 0, syscall.EINVAL
-       }
+       fd.wio.Lock()
+       defer fd.wio.Unlock()
        var o writeToOp
        o.Init(fd, buf, 'w')
        o.sa = sa
-       return iosrv.ExecIO(&o, fd.wdeadline)
+       return iosrv.ExecIO(&o, fd.wdeadline.value())
 }
 
 // Accept new network connections.
@@ -524,12 +618,12 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (*netFD, error) {
        defer fd.decref()
 
        // Get new socket.
-       // See ../syscall/exec.go for description of ForkLock.
+       // See ../syscall/exec_unix.go for description of ForkLock.
        syscall.ForkLock.RLock()
        s, err := syscall.Socket(fd.family, fd.sotype, 0)
        if err != nil {
                syscall.ForkLock.RUnlock()
-               return nil, err
+               return nil, &OpError{"socket", fd.net, fd.laddr, err}
        }
        syscall.CloseOnExec(s)
        syscall.ForkLock.RUnlock()
@@ -537,6 +631,7 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (*netFD, error) {
        // Associate our new socket with IOCP.
        onceStartServer.Do(startServer)
        if _, err := syscall.CreateIoCompletionPort(s, resultsrv.iocp, 0, 0); err != nil {
+               closesocket(s)
                return nil, &OpError{"CreateIoCompletionPort", fd.net, fd.laddr, err}
        }
 
@@ -544,7 +639,7 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (*netFD, error) {
        var o acceptOp
        o.Init(fd, 'r')
        o.newsock = s
-       _, err = iosrv.ExecIO(&o, 0)
+       _, err = iosrv.ExecIO(&o, fd.rdeadline.value())
        if err != nil {
                closesocket(s)
                return nil, err
@@ -554,7 +649,7 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (*netFD, error) {
        err = syscall.Setsockopt(s, syscall.SOL_SOCKET, syscall.SO_UPDATE_ACCEPT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd)))
        if err != nil {
                closesocket(s)
-               return nil, err
+               return nil, &OpError{"Setsockopt", fd.net, fd.laddr, err}
        }
 
        // Get local and peer addr out of AcceptEx buffer.