diff --git a/tunnel/src/main/java/com/zaneschepke/tunnel/backend/TunnelActor.kt b/tunnel/src/main/java/com/zaneschepke/tunnel/backend/TunnelActor.kt index a90827c7..7a836ea7 100644 --- a/tunnel/src/main/java/com/zaneschepke/tunnel/backend/TunnelActor.kt +++ b/tunnel/src/main/java/com/zaneschepke/tunnel/backend/TunnelActor.kt @@ -63,6 +63,7 @@ internal class TunnelActor( ) { private val inbox = Channel(Channel.UNLIMITED) + private val stopping = mutableSetOf() // track running hooks to prevent service shutdown until post down hooks complete private val _runningPostDownHooks = MutableStateFlow(0) @@ -101,27 +102,45 @@ internal class TunnelActor( try { when (cmd) { is TunnelCommand.Start -> { - val result = engine.start(cmd.tunnel, cmd.mode) - apply(TunnelStarted(result, cmd)) - - val runtime = _state.value.byTunnelId[result.tunnelId] ?: continue - - val job = - startTunnelJobs( - tunnelId = result.tunnelId, - runtime = runtime, - removedPeerEndpoint = result.removedPeerEndpoint, + if (stopping.contains(cmd.tunnel.id)) { + Timber.d( + "Tunnel ${cmd.tunnel.id} is still stopping, ignoring rapid start" ) + continue + } - tunnelJobs[result.tunnelId] = job + try { + val result = engine.start(cmd.tunnel, cmd.mode) + stopping -= cmd.tunnel.id + apply(TunnelStarted(result, cmd)) - job.invokeOnCompletion { tunnelJobs.remove(result.tunnelId, job) } + val runtime = _state.value.byTunnelId[result.tunnelId] ?: continue + + val job = + startTunnelJobs( + tunnelId = result.tunnelId, + runtime = runtime, + removedPeerEndpoint = result.removedPeerEndpoint, + ) + + tunnelJobs[result.tunnelId] = job + + job.invokeOnCompletion { tunnelJobs.remove(result.tunnelId, job) } + } finally { + stopping -= cmd.tunnel.id + } } is TunnelCommand.Stop -> { val runtime = _state.value.byTunnelId[cmd.tunnelId] ?: continue - + stopping += cmd.tunnelId engine.stop(runtime.running.handle, runtime.running.mode) + + // Extra safety delay for static listen ports to allow OS time to + // release port + if (runtime.running.mode.config.`interface`.listenPort != null) { + delay(250.milliseconds) + } } is TunnelCommand.SetBootstrapConfig -> { @@ -282,6 +301,7 @@ internal class TunnelActor( } private fun stopTunnel(tunnelId: Int, handle: Int) { + stopping -= tunnelId tunnelJobs.remove(tunnelId)?.cancel() apply(TunnelStopped(tunnelId, handle)) } @@ -689,6 +709,7 @@ internal class TunnelActor( } fun emergencyStop(tunnelId: Int) { + stopping -= tunnelId val runtime = _state.value.byTunnelId[tunnelId] ?: return val handle = runtime.running.handle val mode = runtime.running.mode diff --git a/tunnel/tools/libwg-go/proxy/proxy.go b/tunnel/tools/libwg-go/proxy/proxy.go index 248e0944..57ede3b9 100644 --- a/tunnel/tools/libwg-go/proxy/proxy.go +++ b/tunnel/tools/libwg-go/proxy/proxy.go @@ -35,7 +35,6 @@ func init() { //export awgStartProxy func awgStartProxy(interfaceName string, config string, uapiPath string, bypass int32) int32 { - conf, err := wireproxyawg.ParseConfigString(config) if err != nil { shared.LogError(tag, "Invalid config file", err) @@ -49,22 +48,29 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass } setting, err := wireproxyawg.CreateIPCRequest(conf.Device, false) - if err != nil { shared.LogError(tag, "Create IPC request failed", err) + shared.ReleaseHandle(handle) return -1 } tun, tnet, err := netstack.CreateNetTUN(setting.DeviceAddr, setting.DNS, setting.MTU) if err != nil { shared.LogError(tag, "Create TUN failed", err) + shared.ReleaseHandle(handle) return -1 } name, err := tun.Name() + if err != nil { + shared.LogError(tag, "Failed to get TUN name: %v", err) + shared.ReleaseHandle(handle) + tun.Close() + return -1 + } + // Lockdown modes needs our socket protection var bind conn.Bind - if bypass == 1 { bind = conn.NewStdNetBindWithControl(protectControlFunc) } else { @@ -76,27 +82,28 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass } dev := device.NewDevice(tun, bind, shared.NewLogger("Tun/"+interfaceName), statusCB) - dev.DisableSomeRoamingForBrokenMobileSemantics() err = dev.IpcSet(setting.IpcRequest) - if err != nil { shared.LogError(tag, "Ipc setting failed", err) + shared.ReleaseHandle(handle) + dev.Close() return -1 } - uapiFile, err := ipc.UAPIOpen(uapiPath, name) - var uapi net.Listener - - if err != nil { - shared.LogError(tag, "UAPIOpen: %v", err) + uapiFile, uapiErr := ipc.UAPIOpen(uapiPath, name) + if uapiErr != nil { + shared.LogError(tag, "UAPIOpen: %v", uapiErr) + uapiFile = nil } else { uapi, err = ipc.UAPIListen(uapiPath, name, uapiFile) if err != nil { - uapiFile.Close() shared.LogError(tag, "UAPIListen: %v", err) + uapiFile.Close() + uapiFile = nil + uapi = nil } else { go func() { for { @@ -113,7 +120,13 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass err = dev.Up() if err != nil { shared.LogError(tag, "Failed to bring up device", err) - uapiFile.Close() + if uapiFile != nil { + uapiFile.Close() + } + if uapi != nil { + uapi.Close() + } + shared.ReleaseHandle(handle) dev.Close() return -1 } @@ -127,16 +140,12 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass PingRecord: make(map[string]uint64), PingRecordLock: new(sync.Mutex), } - virtualTunnelHandles[handle] = virtualTun - // Create cancellable context tunnelCtx, tunnelCancel := context.WithCancel(context.Background()) cancelFuncs[handle] = tunnelCancel - // Spawn all routines with context for _, spawner := range conf.Routines { - shared.LogDebug(tag, "Spawning routine..") go func(s wireproxyawg.RoutineSpawner) { if err := s.SpawnRoutine(tunnelCtx, virtualTun); err != nil { shared.LogError(tag, "Routine failed: %v", err) @@ -144,7 +153,7 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass }(spawner) } - shared.LogDebug(tag, "Done starting proxy and tunnel") + shared.LogDebug(tag, "Done starting proxy and tunnel for handle %d", handle) return handle } @@ -219,27 +228,29 @@ func awgTurnProxyTunnelOff(virtualTunnelHandle int32) { } shared.LogDebug(tag, "Tearing down tunnel %d", virtualTunnelHandle) + delete(virtualTunnelHandles, virtualTunnelHandle) + if cancel, exists := cancelFuncs[virtualTunnelHandle]; exists { cancel() delete(cancelFuncs, virtualTunnelHandle) - time.Sleep(50 * time.Millisecond) } - // Close UAPI listener and underlying file if virtualTun.Uapi != nil { virtualTun.Uapi.Close() } - if virtualTun.Dev != nil { virtualTun.Dev.Close() } - go C.awgNotifyStatus( + C.awgNotifyStatus( C.int32_t(virtualTunnelHandle), C.int32_t(shared.StatusStop), ) - delete(virtualTunnelHandles, virtualTunnelHandle) + // Give time for full resource release + time.Sleep(150 * time.Millisecond) + shared.ReleaseHandle(virtualTunnelHandle) - shared.LogDebug(tag, "Tunnel %d fully closed (UAPI/Dev/Bind purged)", virtualTunnelHandle) + + shared.LogDebug(tag, "Tunnel handle %d fully closed", virtualTunnelHandle) } diff --git a/tunnel/tools/libwg-go/vpn/vpn.go b/tunnel/tools/libwg-go/vpn/vpn.go index 076a87f9..4afacf2e 100644 --- a/tunnel/tools/libwg-go/vpn/vpn.go +++ b/tunnel/tools/libwg-go/vpn/vpn.go @@ -13,6 +13,7 @@ import ( "net" "runtime/debug" "strings" + "time" "github.com/amnezia-vpn/amneziawg-go/conn" "github.com/amnezia-vpn/amneziawg-go/device" @@ -40,7 +41,6 @@ func init() { //export awgTurnOn func awgTurnOn(interfaceName string, tunFd int32, settings string, uapiPath string) int32 { tunnel, name, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd)) - if err != nil { unix.Close(int(tunFd)) shared.LogError(tag, "CreateUnmonitoredTUNFromFD: %v", err) @@ -50,52 +50,56 @@ func awgTurnOn(interfaceName string, tunFd int32, settings string, uapiPath stri conf, err := wireproxyawg.ParseConfigString(settings) if err != nil { shared.LogError(tag, "Invalid config file", err) - unix.Close(int(tunFd)) if tunnel != nil { tunnel.Close() } return -1 } - shared.LogDebug(tag, "Creating device with domain blocking enabled: %v", conf.Device.DomainBlockingEnabled) - - handle, err2 := shared.GenerateUniqueHandle() + handle, err := shared.GenerateUniqueHandle() + if err != nil { + shared.LogError(tag, "Unable to generate handle: %v", err) + if tunnel != nil { + tunnel.Close() + } + return -1 + } statusCB := func(code device.StatusCode) { go C.awgNotifyStatus(C.int32_t(handle), C.int32_t(code)) } tunDevice := device.NewDevice(tunnel, conn.NewStdNetBind(), shared.NewLogger("Tun/"+interfaceName), statusCB) - tunDevice.DisableSomeRoamingForBrokenMobileSemantics() ipcRequest, err := wireproxyawg.CreateIPCRequest(conf.Device, false) if err != nil { shared.LogError(tag, "CreateIPCRequest: %v", err) - unix.Close(int(tunFd)) shared.ReleaseHandle(handle) + tunDevice.Close() return -1 } err = tunDevice.IpcSet(ipcRequest.IpcRequest) if err != nil { - unix.Close(int(tunFd)) - shared.ReleaseHandle(handle) shared.LogError(tag, "IpcSet: %v", err) + shared.ReleaseHandle(handle) + tunDevice.Close() return -1 } var uapi net.Listener - - uapiFile, err := ipc.UAPIOpen(uapiPath, name) - - if err != nil { - shared.LogError(tag, "UAPIOpen: %v", err) + uapiFile, uapiErr := ipc.UAPIOpen(uapiPath, name) + if uapiErr != nil { + shared.LogError(tag, "UAPIOpen: %v", uapiErr) + uapiFile = nil } else { - uapi, err = ipc.UAPIListen(uapiPath, name, uapiFile) // uapiPath as rootdir, name as interface + uapi, err = ipc.UAPIListen(uapiPath, name, uapiFile) if err != nil { - uapiFile.Close() shared.LogError(tag, "UAPIListen: %v", err) + uapiFile.Close() + uapiFile = nil + uapi = nil } else { go func() { for { @@ -112,23 +116,20 @@ func awgTurnOn(interfaceName string, tunFd int32, settings string, uapiPath stri err = tunDevice.Up() if err != nil { shared.LogError(tag, "Unable to bring up device: %v", err) - uapiFile.Close() + if uapiFile != nil { + uapiFile.Close() + } + if uapi != nil { + uapi.Close() + } shared.ReleaseHandle(handle) tunDevice.Close() return -1 } - shared.LogDebug(tag, "Device started") - if err2 != nil { - shared.LogError(tag, "Unable to find empty handle", err2) - uapiFile.Close() - shared.ReleaseHandle(handle) - tunDevice.Close() - return -1 - } + shared.LogDebug(tag, "Tunnel started successfully for handle %d", handle) tunnelHandles[handle] = TunnelHandle{device: tunDevice, uapi: uapi} - return handle } @@ -158,7 +159,7 @@ func awgUpdateTunnelPeers(tunnelHandle int32, settings string) int32 { return -1 } - shared.LogDebug(tag, "Configuration updated successfully") + shared.LogDebug(tag, "Configuration updated successfully with handle %d", handle) return 0 } @@ -170,16 +171,20 @@ func awgTurnOff(tunnelHandle int32) { return } - go C.awgNotifyStatus( - C.int32_t(tunnelHandle), - C.int32_t(shared.StatusStop), - ) - delete(tunnelHandles, tunnelHandle) + if handle.uapi != nil { handle.uapi.Close() } - handle.device.Close() + if handle.device != nil { + handle.device.Close() + } + + C.awgNotifyStatus(C.int32_t(tunnelHandle), C.int32_t(shared.StatusStop)) + + // Give time for full resource release + time.Sleep(150 * time.Millisecond) + shared.ReleaseHandle(tunnelHandle) }