mirror of
https://github.com/wgtunnel/android.git
synced 2026-07-03 14:07:49 +02:00
fix: tunnel start/stop race on fast toggles
This commit is contained in:
@@ -63,6 +63,7 @@ internal class TunnelActor(
|
||||
) {
|
||||
|
||||
private val inbox = Channel<TunnelCommand>(Channel.UNLIMITED)
|
||||
private val stopping = mutableSetOf<Int>()
|
||||
|
||||
// track running hooks to prevent service shutdown until post down hooks complete
|
||||
private val _runningPostDownHooks = MutableStateFlow(0)
|
||||
@@ -101,27 +102,45 @@ internal class TunnelActor(
|
||||
try {
|
||||
when (cmd) {
|
||||
is TunnelCommand.Start -> {
|
||||
val result = engine.start(cmd.tunnel, cmd.mode)
|
||||
apply(TunnelStarted(result, cmd))
|
||||
|
||||
val runtime = _state.value.byTunnelId[result.tunnelId] ?: continue
|
||||
|
||||
val job =
|
||||
startTunnelJobs(
|
||||
tunnelId = result.tunnelId,
|
||||
runtime = runtime,
|
||||
removedPeerEndpoint = result.removedPeerEndpoint,
|
||||
if (stopping.contains(cmd.tunnel.id)) {
|
||||
Timber.d(
|
||||
"Tunnel ${cmd.tunnel.id} is still stopping, ignoring rapid start"
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
tunnelJobs[result.tunnelId] = job
|
||||
try {
|
||||
val result = engine.start(cmd.tunnel, cmd.mode)
|
||||
stopping -= cmd.tunnel.id
|
||||
apply(TunnelStarted(result, cmd))
|
||||
|
||||
job.invokeOnCompletion { tunnelJobs.remove(result.tunnelId, job) }
|
||||
val runtime = _state.value.byTunnelId[result.tunnelId] ?: continue
|
||||
|
||||
val job =
|
||||
startTunnelJobs(
|
||||
tunnelId = result.tunnelId,
|
||||
runtime = runtime,
|
||||
removedPeerEndpoint = result.removedPeerEndpoint,
|
||||
)
|
||||
|
||||
tunnelJobs[result.tunnelId] = job
|
||||
|
||||
job.invokeOnCompletion { tunnelJobs.remove(result.tunnelId, job) }
|
||||
} finally {
|
||||
stopping -= cmd.tunnel.id
|
||||
}
|
||||
}
|
||||
|
||||
is TunnelCommand.Stop -> {
|
||||
val runtime = _state.value.byTunnelId[cmd.tunnelId] ?: continue
|
||||
|
||||
stopping += cmd.tunnelId
|
||||
engine.stop(runtime.running.handle, runtime.running.mode)
|
||||
|
||||
// Extra safety delay for static listen ports to allow OS time to
|
||||
// release port
|
||||
if (runtime.running.mode.config.`interface`.listenPort != null) {
|
||||
delay(250.milliseconds)
|
||||
}
|
||||
}
|
||||
|
||||
is TunnelCommand.SetBootstrapConfig -> {
|
||||
@@ -282,6 +301,7 @@ internal class TunnelActor(
|
||||
}
|
||||
|
||||
private fun stopTunnel(tunnelId: Int, handle: Int) {
|
||||
stopping -= tunnelId
|
||||
tunnelJobs.remove(tunnelId)?.cancel()
|
||||
apply(TunnelStopped(tunnelId, handle))
|
||||
}
|
||||
@@ -689,6 +709,7 @@ internal class TunnelActor(
|
||||
}
|
||||
|
||||
fun emergencyStop(tunnelId: Int) {
|
||||
stopping -= tunnelId
|
||||
val runtime = _state.value.byTunnelId[tunnelId] ?: return
|
||||
val handle = runtime.running.handle
|
||||
val mode = runtime.running.mode
|
||||
|
||||
@@ -35,7 +35,6 @@ func init() {
|
||||
|
||||
//export awgStartProxy
|
||||
func awgStartProxy(interfaceName string, config string, uapiPath string, bypass int32) int32 {
|
||||
|
||||
conf, err := wireproxyawg.ParseConfigString(config)
|
||||
if err != nil {
|
||||
shared.LogError(tag, "Invalid config file", err)
|
||||
@@ -49,22 +48,29 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass
|
||||
}
|
||||
|
||||
setting, err := wireproxyawg.CreateIPCRequest(conf.Device, false)
|
||||
|
||||
if err != nil {
|
||||
shared.LogError(tag, "Create IPC request failed", err)
|
||||
shared.ReleaseHandle(handle)
|
||||
return -1
|
||||
}
|
||||
|
||||
tun, tnet, err := netstack.CreateNetTUN(setting.DeviceAddr, setting.DNS, setting.MTU)
|
||||
if err != nil {
|
||||
shared.LogError(tag, "Create TUN failed", err)
|
||||
shared.ReleaseHandle(handle)
|
||||
return -1
|
||||
}
|
||||
|
||||
name, err := tun.Name()
|
||||
if err != nil {
|
||||
shared.LogError(tag, "Failed to get TUN name: %v", err)
|
||||
shared.ReleaseHandle(handle)
|
||||
tun.Close()
|
||||
return -1
|
||||
}
|
||||
|
||||
// Lockdown modes needs our socket protection
|
||||
var bind conn.Bind
|
||||
|
||||
if bypass == 1 {
|
||||
bind = conn.NewStdNetBindWithControl(protectControlFunc)
|
||||
} else {
|
||||
@@ -76,27 +82,28 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass
|
||||
}
|
||||
|
||||
dev := device.NewDevice(tun, bind, shared.NewLogger("Tun/"+interfaceName), statusCB)
|
||||
|
||||
dev.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
|
||||
err = dev.IpcSet(setting.IpcRequest)
|
||||
|
||||
if err != nil {
|
||||
shared.LogError(tag, "Ipc setting failed", err)
|
||||
shared.ReleaseHandle(handle)
|
||||
dev.Close()
|
||||
return -1
|
||||
}
|
||||
|
||||
uapiFile, err := ipc.UAPIOpen(uapiPath, name)
|
||||
|
||||
var uapi net.Listener
|
||||
|
||||
if err != nil {
|
||||
shared.LogError(tag, "UAPIOpen: %v", err)
|
||||
uapiFile, uapiErr := ipc.UAPIOpen(uapiPath, name)
|
||||
if uapiErr != nil {
|
||||
shared.LogError(tag, "UAPIOpen: %v", uapiErr)
|
||||
uapiFile = nil
|
||||
} else {
|
||||
uapi, err = ipc.UAPIListen(uapiPath, name, uapiFile)
|
||||
if err != nil {
|
||||
uapiFile.Close()
|
||||
shared.LogError(tag, "UAPIListen: %v", err)
|
||||
uapiFile.Close()
|
||||
uapiFile = nil
|
||||
uapi = nil
|
||||
} else {
|
||||
go func() {
|
||||
for {
|
||||
@@ -113,7 +120,13 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass
|
||||
err = dev.Up()
|
||||
if err != nil {
|
||||
shared.LogError(tag, "Failed to bring up device", err)
|
||||
uapiFile.Close()
|
||||
if uapiFile != nil {
|
||||
uapiFile.Close()
|
||||
}
|
||||
if uapi != nil {
|
||||
uapi.Close()
|
||||
}
|
||||
shared.ReleaseHandle(handle)
|
||||
dev.Close()
|
||||
return -1
|
||||
}
|
||||
@@ -127,16 +140,12 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass
|
||||
PingRecord: make(map[string]uint64),
|
||||
PingRecordLock: new(sync.Mutex),
|
||||
}
|
||||
|
||||
virtualTunnelHandles[handle] = virtualTun
|
||||
|
||||
// Create cancellable context
|
||||
tunnelCtx, tunnelCancel := context.WithCancel(context.Background())
|
||||
cancelFuncs[handle] = tunnelCancel
|
||||
|
||||
// Spawn all routines with context
|
||||
for _, spawner := range conf.Routines {
|
||||
shared.LogDebug(tag, "Spawning routine..")
|
||||
go func(s wireproxyawg.RoutineSpawner) {
|
||||
if err := s.SpawnRoutine(tunnelCtx, virtualTun); err != nil {
|
||||
shared.LogError(tag, "Routine failed: %v", err)
|
||||
@@ -144,7 +153,7 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass
|
||||
}(spawner)
|
||||
}
|
||||
|
||||
shared.LogDebug(tag, "Done starting proxy and tunnel")
|
||||
shared.LogDebug(tag, "Done starting proxy and tunnel for handle %d", handle)
|
||||
return handle
|
||||
}
|
||||
|
||||
@@ -219,27 +228,29 @@ func awgTurnProxyTunnelOff(virtualTunnelHandle int32) {
|
||||
}
|
||||
shared.LogDebug(tag, "Tearing down tunnel %d", virtualTunnelHandle)
|
||||
|
||||
delete(virtualTunnelHandles, virtualTunnelHandle)
|
||||
|
||||
if cancel, exists := cancelFuncs[virtualTunnelHandle]; exists {
|
||||
cancel()
|
||||
delete(cancelFuncs, virtualTunnelHandle)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Close UAPI listener and underlying file
|
||||
if virtualTun.Uapi != nil {
|
||||
virtualTun.Uapi.Close()
|
||||
}
|
||||
|
||||
if virtualTun.Dev != nil {
|
||||
virtualTun.Dev.Close()
|
||||
}
|
||||
|
||||
go C.awgNotifyStatus(
|
||||
C.awgNotifyStatus(
|
||||
C.int32_t(virtualTunnelHandle),
|
||||
C.int32_t(shared.StatusStop),
|
||||
)
|
||||
|
||||
delete(virtualTunnelHandles, virtualTunnelHandle)
|
||||
// Give time for full resource release
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
shared.ReleaseHandle(virtualTunnelHandle)
|
||||
shared.LogDebug(tag, "Tunnel %d fully closed (UAPI/Dev/Bind purged)", virtualTunnelHandle)
|
||||
|
||||
shared.LogDebug(tag, "Tunnel handle %d fully closed", virtualTunnelHandle)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"net"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/device"
|
||||
@@ -40,7 +41,6 @@ func init() {
|
||||
//export awgTurnOn
|
||||
func awgTurnOn(interfaceName string, tunFd int32, settings string, uapiPath string) int32 {
|
||||
tunnel, name, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd))
|
||||
|
||||
if err != nil {
|
||||
unix.Close(int(tunFd))
|
||||
shared.LogError(tag, "CreateUnmonitoredTUNFromFD: %v", err)
|
||||
@@ -50,52 +50,56 @@ func awgTurnOn(interfaceName string, tunFd int32, settings string, uapiPath stri
|
||||
conf, err := wireproxyawg.ParseConfigString(settings)
|
||||
if err != nil {
|
||||
shared.LogError(tag, "Invalid config file", err)
|
||||
unix.Close(int(tunFd))
|
||||
if tunnel != nil {
|
||||
tunnel.Close()
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
shared.LogDebug(tag, "Creating device with domain blocking enabled: %v", conf.Device.DomainBlockingEnabled)
|
||||
|
||||
handle, err2 := shared.GenerateUniqueHandle()
|
||||
handle, err := shared.GenerateUniqueHandle()
|
||||
if err != nil {
|
||||
shared.LogError(tag, "Unable to generate handle: %v", err)
|
||||
if tunnel != nil {
|
||||
tunnel.Close()
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
statusCB := func(code device.StatusCode) {
|
||||
go C.awgNotifyStatus(C.int32_t(handle), C.int32_t(code))
|
||||
}
|
||||
|
||||
tunDevice := device.NewDevice(tunnel, conn.NewStdNetBind(), shared.NewLogger("Tun/"+interfaceName), statusCB)
|
||||
|
||||
tunDevice.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
|
||||
ipcRequest, err := wireproxyawg.CreateIPCRequest(conf.Device, false)
|
||||
if err != nil {
|
||||
shared.LogError(tag, "CreateIPCRequest: %v", err)
|
||||
unix.Close(int(tunFd))
|
||||
shared.ReleaseHandle(handle)
|
||||
tunDevice.Close()
|
||||
return -1
|
||||
}
|
||||
|
||||
err = tunDevice.IpcSet(ipcRequest.IpcRequest)
|
||||
if err != nil {
|
||||
unix.Close(int(tunFd))
|
||||
shared.ReleaseHandle(handle)
|
||||
shared.LogError(tag, "IpcSet: %v", err)
|
||||
shared.ReleaseHandle(handle)
|
||||
tunDevice.Close()
|
||||
return -1
|
||||
}
|
||||
|
||||
var uapi net.Listener
|
||||
|
||||
uapiFile, err := ipc.UAPIOpen(uapiPath, name)
|
||||
|
||||
if err != nil {
|
||||
shared.LogError(tag, "UAPIOpen: %v", err)
|
||||
uapiFile, uapiErr := ipc.UAPIOpen(uapiPath, name)
|
||||
if uapiErr != nil {
|
||||
shared.LogError(tag, "UAPIOpen: %v", uapiErr)
|
||||
uapiFile = nil
|
||||
} else {
|
||||
uapi, err = ipc.UAPIListen(uapiPath, name, uapiFile) // uapiPath as rootdir, name as interface
|
||||
uapi, err = ipc.UAPIListen(uapiPath, name, uapiFile)
|
||||
if err != nil {
|
||||
uapiFile.Close()
|
||||
shared.LogError(tag, "UAPIListen: %v", err)
|
||||
uapiFile.Close()
|
||||
uapiFile = nil
|
||||
uapi = nil
|
||||
} else {
|
||||
go func() {
|
||||
for {
|
||||
@@ -112,23 +116,20 @@ func awgTurnOn(interfaceName string, tunFd int32, settings string, uapiPath stri
|
||||
err = tunDevice.Up()
|
||||
if err != nil {
|
||||
shared.LogError(tag, "Unable to bring up device: %v", err)
|
||||
uapiFile.Close()
|
||||
if uapiFile != nil {
|
||||
uapiFile.Close()
|
||||
}
|
||||
if uapi != nil {
|
||||
uapi.Close()
|
||||
}
|
||||
shared.ReleaseHandle(handle)
|
||||
tunDevice.Close()
|
||||
return -1
|
||||
}
|
||||
shared.LogDebug(tag, "Device started")
|
||||
|
||||
if err2 != nil {
|
||||
shared.LogError(tag, "Unable to find empty handle", err2)
|
||||
uapiFile.Close()
|
||||
shared.ReleaseHandle(handle)
|
||||
tunDevice.Close()
|
||||
return -1
|
||||
}
|
||||
shared.LogDebug(tag, "Tunnel started successfully for handle %d", handle)
|
||||
|
||||
tunnelHandles[handle] = TunnelHandle{device: tunDevice, uapi: uapi}
|
||||
|
||||
return handle
|
||||
}
|
||||
|
||||
@@ -158,7 +159,7 @@ func awgUpdateTunnelPeers(tunnelHandle int32, settings string) int32 {
|
||||
return -1
|
||||
}
|
||||
|
||||
shared.LogDebug(tag, "Configuration updated successfully")
|
||||
shared.LogDebug(tag, "Configuration updated successfully with handle %d", handle)
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -170,16 +171,20 @@ func awgTurnOff(tunnelHandle int32) {
|
||||
return
|
||||
}
|
||||
|
||||
go C.awgNotifyStatus(
|
||||
C.int32_t(tunnelHandle),
|
||||
C.int32_t(shared.StatusStop),
|
||||
)
|
||||
|
||||
delete(tunnelHandles, tunnelHandle)
|
||||
|
||||
if handle.uapi != nil {
|
||||
handle.uapi.Close()
|
||||
}
|
||||
handle.device.Close()
|
||||
if handle.device != nil {
|
||||
handle.device.Close()
|
||||
}
|
||||
|
||||
C.awgNotifyStatus(C.int32_t(tunnelHandle), C.int32_t(shared.StatusStop))
|
||||
|
||||
// Give time for full resource release
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
shared.ReleaseHandle(tunnelHandle)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user