fix: tunnel start/stop race on fast toggles

This commit is contained in:
zaneschepke
2026-06-05 04:24:34 -04:00
parent 5bc49eec50
commit abdbf74755
3 changed files with 106 additions and 69 deletions
@@ -63,6 +63,7 @@ internal class TunnelActor(
) {
private val inbox = Channel<TunnelCommand>(Channel.UNLIMITED)
private val stopping = mutableSetOf<Int>()
// track running hooks to prevent service shutdown until post down hooks complete
private val _runningPostDownHooks = MutableStateFlow(0)
@@ -101,27 +102,45 @@ internal class TunnelActor(
try {
when (cmd) {
is TunnelCommand.Start -> {
val result = engine.start(cmd.tunnel, cmd.mode)
apply(TunnelStarted(result, cmd))
val runtime = _state.value.byTunnelId[result.tunnelId] ?: continue
val job =
startTunnelJobs(
tunnelId = result.tunnelId,
runtime = runtime,
removedPeerEndpoint = result.removedPeerEndpoint,
if (stopping.contains(cmd.tunnel.id)) {
Timber.d(
"Tunnel ${cmd.tunnel.id} is still stopping, ignoring rapid start"
)
continue
}
tunnelJobs[result.tunnelId] = job
try {
val result = engine.start(cmd.tunnel, cmd.mode)
stopping -= cmd.tunnel.id
apply(TunnelStarted(result, cmd))
job.invokeOnCompletion { tunnelJobs.remove(result.tunnelId, job) }
val runtime = _state.value.byTunnelId[result.tunnelId] ?: continue
val job =
startTunnelJobs(
tunnelId = result.tunnelId,
runtime = runtime,
removedPeerEndpoint = result.removedPeerEndpoint,
)
tunnelJobs[result.tunnelId] = job
job.invokeOnCompletion { tunnelJobs.remove(result.tunnelId, job) }
} finally {
stopping -= cmd.tunnel.id
}
}
is TunnelCommand.Stop -> {
val runtime = _state.value.byTunnelId[cmd.tunnelId] ?: continue
stopping += cmd.tunnelId
engine.stop(runtime.running.handle, runtime.running.mode)
// Extra safety delay for static listen ports to allow OS time to
// release port
if (runtime.running.mode.config.`interface`.listenPort != null) {
delay(250.milliseconds)
}
}
is TunnelCommand.SetBootstrapConfig -> {
@@ -282,6 +301,7 @@ internal class TunnelActor(
}
/**
 * Records that the tunnel identified by [tunnelId] has stopped.
 *
 * Clears the tunnel's entry from the `stopping` set, cancels the job tracked
 * for it in `tunnelJobs` (if one is registered), and then applies a
 * [TunnelStopped] event carrying [tunnelId] and [handle].
 *
 * @param tunnelId id of the tunnel that stopped.
 * @param handle handle value forwarded unchanged in the [TunnelStopped] event.
 */
private fun stopTunnel(tunnelId: Int, handle: Int) {
    // Clear the "still stopping" marker for this tunnel before anything else.
    stopping.remove(tunnelId)

    // Detach the tracked job first, then cancel it; nothing to cancel if absent.
    val trackedJob = tunnelJobs.remove(tunnelId)
    trackedJob?.cancel()

    val stoppedEvent = TunnelStopped(tunnelId, handle)
    apply(stoppedEvent)
}
@@ -689,6 +709,7 @@ internal class TunnelActor(
}
fun emergencyStop(tunnelId: Int) {
stopping -= tunnelId
val runtime = _state.value.byTunnelId[tunnelId] ?: return
val handle = runtime.running.handle
val mode = runtime.running.mode
+34 -23
View File
@@ -35,7 +35,6 @@ func init() {
//export awgStartProxy
func awgStartProxy(interfaceName string, config string, uapiPath string, bypass int32) int32 {
conf, err := wireproxyawg.ParseConfigString(config)
if err != nil {
shared.LogError(tag, "Invalid config file", err)
@@ -49,22 +48,29 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass
}
setting, err := wireproxyawg.CreateIPCRequest(conf.Device, false)
if err != nil {
shared.LogError(tag, "Create IPC request failed", err)
shared.ReleaseHandle(handle)
return -1
}
tun, tnet, err := netstack.CreateNetTUN(setting.DeviceAddr, setting.DNS, setting.MTU)
if err != nil {
shared.LogError(tag, "Create TUN failed", err)
shared.ReleaseHandle(handle)
return -1
}
name, err := tun.Name()
if err != nil {
shared.LogError(tag, "Failed to get TUN name: %v", err)
shared.ReleaseHandle(handle)
tun.Close()
return -1
}
// Lockdown mode needs our socket protection
var bind conn.Bind
if bypass == 1 {
bind = conn.NewStdNetBindWithControl(protectControlFunc)
} else {
@@ -76,27 +82,28 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass
}
dev := device.NewDevice(tun, bind, shared.NewLogger("Tun/"+interfaceName), statusCB)
dev.DisableSomeRoamingForBrokenMobileSemantics()
err = dev.IpcSet(setting.IpcRequest)
if err != nil {
shared.LogError(tag, "Ipc setting failed", err)
shared.ReleaseHandle(handle)
dev.Close()
return -1
}
uapiFile, err := ipc.UAPIOpen(uapiPath, name)
var uapi net.Listener
if err != nil {
shared.LogError(tag, "UAPIOpen: %v", err)
uapiFile, uapiErr := ipc.UAPIOpen(uapiPath, name)
if uapiErr != nil {
shared.LogError(tag, "UAPIOpen: %v", uapiErr)
uapiFile = nil
} else {
uapi, err = ipc.UAPIListen(uapiPath, name, uapiFile)
if err != nil {
uapiFile.Close()
shared.LogError(tag, "UAPIListen: %v", err)
uapiFile.Close()
uapiFile = nil
uapi = nil
} else {
go func() {
for {
@@ -113,7 +120,13 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass
err = dev.Up()
if err != nil {
shared.LogError(tag, "Failed to bring up device", err)
uapiFile.Close()
if uapiFile != nil {
uapiFile.Close()
}
if uapi != nil {
uapi.Close()
}
shared.ReleaseHandle(handle)
dev.Close()
return -1
}
@@ -127,16 +140,12 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass
PingRecord: make(map[string]uint64),
PingRecordLock: new(sync.Mutex),
}
virtualTunnelHandles[handle] = virtualTun
// Create cancellable context
tunnelCtx, tunnelCancel := context.WithCancel(context.Background())
cancelFuncs[handle] = tunnelCancel
// Spawn all routines with context
for _, spawner := range conf.Routines {
shared.LogDebug(tag, "Spawning routine..")
go func(s wireproxyawg.RoutineSpawner) {
if err := s.SpawnRoutine(tunnelCtx, virtualTun); err != nil {
shared.LogError(tag, "Routine failed: %v", err)
@@ -144,7 +153,7 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass
}(spawner)
}
shared.LogDebug(tag, "Done starting proxy and tunnel")
shared.LogDebug(tag, "Done starting proxy and tunnel for handle %d", handle)
return handle
}
@@ -219,27 +228,29 @@ func awgTurnProxyTunnelOff(virtualTunnelHandle int32) {
}
shared.LogDebug(tag, "Tearing down tunnel %d", virtualTunnelHandle)
delete(virtualTunnelHandles, virtualTunnelHandle)
if cancel, exists := cancelFuncs[virtualTunnelHandle]; exists {
cancel()
delete(cancelFuncs, virtualTunnelHandle)
time.Sleep(50 * time.Millisecond)
}
// Close UAPI listener and underlying file
if virtualTun.Uapi != nil {
virtualTun.Uapi.Close()
}
if virtualTun.Dev != nil {
virtualTun.Dev.Close()
}
go C.awgNotifyStatus(
C.awgNotifyStatus(
C.int32_t(virtualTunnelHandle),
C.int32_t(shared.StatusStop),
)
delete(virtualTunnelHandles, virtualTunnelHandle)
// Give time for full resource release
time.Sleep(150 * time.Millisecond)
shared.ReleaseHandle(virtualTunnelHandle)
shared.LogDebug(tag, "Tunnel %d fully closed (UAPI/Dev/Bind purged)", virtualTunnelHandle)
shared.LogDebug(tag, "Tunnel handle %d fully closed", virtualTunnelHandle)
}
+38 -33
View File
@@ -13,6 +13,7 @@ import (
"net"
"runtime/debug"
"strings"
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device"
@@ -40,7 +41,6 @@ func init() {
//export awgTurnOn
func awgTurnOn(interfaceName string, tunFd int32, settings string, uapiPath string) int32 {
tunnel, name, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd))
if err != nil {
unix.Close(int(tunFd))
shared.LogError(tag, "CreateUnmonitoredTUNFromFD: %v", err)
@@ -50,52 +50,56 @@ func awgTurnOn(interfaceName string, tunFd int32, settings string, uapiPath stri
conf, err := wireproxyawg.ParseConfigString(settings)
if err != nil {
shared.LogError(tag, "Invalid config file", err)
unix.Close(int(tunFd))
if tunnel != nil {
tunnel.Close()
}
return -1
}
shared.LogDebug(tag, "Creating device with domain blocking enabled: %v", conf.Device.DomainBlockingEnabled)
handle, err2 := shared.GenerateUniqueHandle()
handle, err := shared.GenerateUniqueHandle()
if err != nil {
shared.LogError(tag, "Unable to generate handle: %v", err)
if tunnel != nil {
tunnel.Close()
}
return -1
}
statusCB := func(code device.StatusCode) {
go C.awgNotifyStatus(C.int32_t(handle), C.int32_t(code))
}
tunDevice := device.NewDevice(tunnel, conn.NewStdNetBind(), shared.NewLogger("Tun/"+interfaceName), statusCB)
tunDevice.DisableSomeRoamingForBrokenMobileSemantics()
ipcRequest, err := wireproxyawg.CreateIPCRequest(conf.Device, false)
if err != nil {
shared.LogError(tag, "CreateIPCRequest: %v", err)
unix.Close(int(tunFd))
shared.ReleaseHandle(handle)
tunDevice.Close()
return -1
}
err = tunDevice.IpcSet(ipcRequest.IpcRequest)
if err != nil {
unix.Close(int(tunFd))
shared.ReleaseHandle(handle)
shared.LogError(tag, "IpcSet: %v", err)
shared.ReleaseHandle(handle)
tunDevice.Close()
return -1
}
var uapi net.Listener
uapiFile, err := ipc.UAPIOpen(uapiPath, name)
if err != nil {
shared.LogError(tag, "UAPIOpen: %v", err)
uapiFile, uapiErr := ipc.UAPIOpen(uapiPath, name)
if uapiErr != nil {
shared.LogError(tag, "UAPIOpen: %v", uapiErr)
uapiFile = nil
} else {
uapi, err = ipc.UAPIListen(uapiPath, name, uapiFile) // uapiPath as rootdir, name as interface
uapi, err = ipc.UAPIListen(uapiPath, name, uapiFile)
if err != nil {
uapiFile.Close()
shared.LogError(tag, "UAPIListen: %v", err)
uapiFile.Close()
uapiFile = nil
uapi = nil
} else {
go func() {
for {
@@ -112,23 +116,20 @@ func awgTurnOn(interfaceName string, tunFd int32, settings string, uapiPath stri
err = tunDevice.Up()
if err != nil {
shared.LogError(tag, "Unable to bring up device: %v", err)
uapiFile.Close()
if uapiFile != nil {
uapiFile.Close()
}
if uapi != nil {
uapi.Close()
}
shared.ReleaseHandle(handle)
tunDevice.Close()
return -1
}
shared.LogDebug(tag, "Device started")
if err2 != nil {
shared.LogError(tag, "Unable to find empty handle", err2)
uapiFile.Close()
shared.ReleaseHandle(handle)
tunDevice.Close()
return -1
}
shared.LogDebug(tag, "Tunnel started successfully for handle %d", handle)
tunnelHandles[handle] = TunnelHandle{device: tunDevice, uapi: uapi}
return handle
}
@@ -158,7 +159,7 @@ func awgUpdateTunnelPeers(tunnelHandle int32, settings string) int32 {
return -1
}
shared.LogDebug(tag, "Configuration updated successfully")
shared.LogDebug(tag, "Configuration updated successfully with handle %d", handle)
return 0
}
@@ -170,16 +171,20 @@ func awgTurnOff(tunnelHandle int32) {
return
}
go C.awgNotifyStatus(
C.int32_t(tunnelHandle),
C.int32_t(shared.StatusStop),
)
delete(tunnelHandles, tunnelHandle)
if handle.uapi != nil {
handle.uapi.Close()
}
handle.device.Close()
if handle.device != nil {
handle.device.Close()
}
C.awgNotifyStatus(C.int32_t(tunnelHandle), C.int32_t(shared.StatusStop))
// Give time for full resource release
time.Sleep(150 * time.Millisecond)
shared.ReleaseHandle(tunnelHandle)
}