From 9ee1fa69ed4bf596121ca4d66114595b4f76e76c Mon Sep 17 00:00:00 2001 From: zaneschepke Date: Wed, 17 Jun 2026 02:32:55 -0400 Subject: [PATCH] fix: tunnel sockets protection race This race was especially impacting GrapheneOS devices #1274 --- .../java/com/zaneschepke/tunnel/VpnBackend.kt | 4 - .../tunnel/backend/ServiceHolder.kt | 80 ++++++++----------- .../tunnel/backend/WireGuardTunnelEngine.kt | 34 +++++--- .../tunnel/service/TunnelService.kt | 2 +- .../zaneschepke/tunnel/service/VpnService.kt | 5 +- tunnel/tools/libwg-go/proxy/proxy.go | 22 +---- tunnel/tools/libwg-go/shared/protect.go | 24 ++++++ tunnel/tools/libwg-go/vpn/vpn.go | 50 +----------- tunnel/tools/libwg-go/vpn/vpn_jni.c | 12 --- 9 files changed, 84 insertions(+), 149 deletions(-) create mode 100644 tunnel/tools/libwg-go/shared/protect.go diff --git a/tunnel/src/main/java/com/zaneschepke/tunnel/VpnBackend.kt b/tunnel/src/main/java/com/zaneschepke/tunnel/VpnBackend.kt index 577af333..4f473931 100644 --- a/tunnel/src/main/java/com/zaneschepke/tunnel/VpnBackend.kt +++ b/tunnel/src/main/java/com/zaneschepke/tunnel/VpnBackend.kt @@ -15,10 +15,6 @@ internal object VpnBackend { external fun awgGetConfig(handle: Int): String? - external fun awgGetSocketV4(handle: Int): Int - - external fun awgGetSocketV6(handle: Int): Int - external fun awgTurnOff(handle: Int) external fun awgTurnOn(ifName: String, tunFd: Int, settings: String, uapiPath: String): Int diff --git a/tunnel/src/main/java/com/zaneschepke/tunnel/backend/ServiceHolder.kt b/tunnel/src/main/java/com/zaneschepke/tunnel/backend/ServiceHolder.kt index db947181..1e0d8523 100644 --- a/tunnel/src/main/java/com/zaneschepke/tunnel/backend/ServiceHolder.kt +++ b/tunnel/src/main/java/com/zaneschepke/tunnel/backend/ServiceHolder.kt @@ -2,14 +2,18 @@ package com.zaneschepke.tunnel.backend import android.content.Context import android.content.Intent +import com.zaneschepke.tunnel.ProxyBackend import com.zaneschepke.tunnel.service.TunnelService import com.zaneschepke.tunnel.service.VpnService import com.zaneschepke.tunnel.util.BackendException import java.lang.ref.WeakReference import kotlin.time.Duration.Companion.milliseconds -import kotlinx.coroutines.CompletableDeferred -import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.TimeoutCancellationException +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.filterNotNull +import kotlinx.coroutines.flow.first import kotlinx.coroutines.withTimeout import kotlinx.coroutines.withTimeoutOrNull import timber.log.Timber @@ -18,89 +22,71 @@ internal class ServiceHolder(val context: Context) { internal val uapiPath = context.dataDir.absolutePath - @Volatile private var vpnService = CompletableDeferred() - @Volatile private var tunnelService = CompletableDeferred() - @Volatile private var vpnServiceDestroyed = CompletableDeferred() - @Volatile private var tunnelServiceDestroyed = CompletableDeferred() + private val _vpnService = MutableStateFlow(null) + val vpnServiceFlow: StateFlow = _vpnService.asStateFlow() + private val _tunnelService = MutableStateFlow(null) + val tunnelServiceFlow: StateFlow = _tunnelService.asStateFlow() fun set(service: VpnService) { - vpnService.complete(service) + _vpnService.value = service + ProxyBackend.setSocketProtector(service) } fun set(service: TunnelService) { - tunnelService.complete(service) + _tunnelService.value = service } - fun signalVpnServiceDestroyed() { - vpnServiceDestroyed.complete(Unit) + fun clearVpnService() { + ProxyBackend.setSocketProtector(null) + _vpnService.value = null } - fun signalTunnelServiceDestroyed() { - tunnelServiceDestroyed.complete(Unit) + fun clearTunnelService() { + _tunnelService.value = null } - @OptIn(ExperimentalCoroutinesApi::class) suspend fun getVpnService(): VpnService { - if (vpnService.isCompleted && !vpnService.isCancelled) { - return vpnService.getCompleted() - } - if (android.net.VpnService.prepare(context) != null) { throw BackendException.Unauthorized("Permission unavailable to use VpnService") } - context.startForegroundService(Intent(context, VpnService::class.java)) + if (_vpnService.value == null) { + context.startForegroundService(Intent(context, VpnService::class.java)) + } return try { - withTimeout(3_000L.milliseconds) { vpnService.await() } + withTimeout(3_000L.milliseconds) { vpnServiceFlow.filterNotNull().first() } } catch (e: TimeoutCancellationException) { Timber.e(e, "Timed out getting VpnService") throw BackendException.InternalError("Failed to get VpnService") } } - @OptIn(ExperimentalCoroutinesApi::class) suspend fun getTunnelService(): TunnelService { - if (tunnelService.isCompleted && !tunnelService.isCancelled) { - return tunnelService.getCompleted() + if (_tunnelService.value == null) { + context.startForegroundService(Intent(context, TunnelService::class.java)) } - context.startForegroundService(Intent(context, TunnelService::class.java)) - return try { - withTimeout(3_000L.milliseconds) { tunnelService.await() } + withTimeout(3_000L.milliseconds) { tunnelServiceFlow.filterNotNull().first() } } catch (e: TimeoutCancellationException) { Timber.e(e, "Timed out getting TunnelService") throw BackendException.InternalError("Failed to get TunnelService") } } - @OptIn(ExperimentalCoroutinesApi::class) - suspend fun stopTunnelService() { - val service = - if (tunnelService.isCompleted && !tunnelService.isCancelled) { - tunnelService.getCompleted() - } else return - - tunnelServiceDestroyed = CompletableDeferred() - + suspend fun stopVpnService() { + val service = _vpnService.value ?: return + clearVpnService() service.shutdown() - tunnelService = CompletableDeferred() - withTimeoutOrNull(1_000L.milliseconds) { tunnelServiceDestroyed.await() } + withTimeoutOrNull(1_000L.milliseconds) { vpnServiceFlow.first { it == null } } } - @OptIn(ExperimentalCoroutinesApi::class) - suspend fun stopVpnService() { - val service = - if (vpnService.isCompleted && !vpnService.isCancelled) { - vpnService.getCompleted() - } else return - - vpnServiceDestroyed = CompletableDeferred() - + suspend fun stopTunnelService() { + val service = _tunnelService.value ?: return + clearTunnelService() service.shutdown() - vpnService = CompletableDeferred() - withTimeoutOrNull(1_000L.milliseconds) { vpnServiceDestroyed.await() } + withTimeoutOrNull(1_000L.milliseconds) { tunnelServiceFlow.first { it == null } } } companion object { diff --git a/tunnel/src/main/java/com/zaneschepke/tunnel/backend/WireGuardTunnelEngine.kt b/tunnel/src/main/java/com/zaneschepke/tunnel/backend/WireGuardTunnelEngine.kt index c3066777..c8e73d93 100644 --- a/tunnel/src/main/java/com/zaneschepke/tunnel/backend/WireGuardTunnelEngine.kt +++ b/tunnel/src/main/java/com/zaneschepke/tunnel/backend/WireGuardTunnelEngine.kt @@ -176,8 +176,7 @@ internal class WireGuardTunnelEngine(private val serviceHolder: ServiceHolder) : } private suspend fun startVpnTunnel(tunnel: Tunnel, ifName: String, config: Config): Int { - - val service = serviceHolder.getVpnService() + val service = ensureVpnProtectorRegistered() val fd = service.createTunInterface(tunnel, config)?.detachFd() @@ -185,14 +184,9 @@ internal class WireGuardTunnelEngine(private val serviceHolder: ServiceHolder) : val handle = VpnBackend.awgTurnOn(ifName, fd, config.asQuickString(), serviceHolder.uapiPath) - if (handle < 0) { throw BackendException.InternalError("Internal native error with code: $handle") } - - service.protect(VpnBackend.awgGetSocketV4(handle)) - service.protect(VpnBackend.awgGetSocketV6(handle)) - return handle } @@ -224,13 +218,20 @@ internal class WireGuardTunnelEngine(private val serviceHolder: ServiceHolder) : proxyConfig: ProxyConfig, withBridge: Boolean, ): Int { - val quickConfig = buildProxiedQuickString(config, proxyConfig) if (!withBridge) { serviceHolder.getTunnelService() } + // Get VpnService and ensure protector is registered + val vpnService = + if (withBridge) { + ensureVpnProtectorRegistered() + } else { + null + } + val handle = ProxyBackend.awgStartProxy( ifName, @@ -238,12 +239,12 @@ internal class WireGuardTunnelEngine(private val serviceHolder: ServiceHolder) : serviceHolder.uapiPath, if (withBridge) 1 else 0, ) - if (handle < 0) { throw BackendException.InternalError("Internal native error") } - if (withBridge) { + // Start HEV bridge after the proxy tunnel is up + if (withBridge && vpnService != null) { val port = proxyConfig.socks5?.port ?: throw BackendException.InternalError( @@ -254,12 +255,23 @@ internal class WireGuardTunnelEngine(private val serviceHolder: ServiceHolder) : ?: throw BackendException.InternalError( "Bridge pass not set for kill switch proxy config" ) - serviceHolder.getVpnService().startHevSocks5Bridge(port, pass) + + vpnService.startHevSocks5Bridge(port, pass) } return handle } + /** + * Gets the VpnService and starts if needed while ensuring the protector is registered. This is + * needed before any native call that uses NewStdNetBindWithControl. + */ + private suspend fun ensureVpnProtectorRegistered(): VpnService { + val service = serviceHolder.getVpnService() + ProxyBackend.setSocketProtector(service) + return service + } + private fun buildProxiedQuickString(config: Config, proxyConfig: ProxyConfig): String { return buildString { append(config.asQuickString()) diff --git a/tunnel/src/main/java/com/zaneschepke/tunnel/service/TunnelService.kt b/tunnel/src/main/java/com/zaneschepke/tunnel/service/TunnelService.kt index bb601e4d..31c7f0d1 100644 --- a/tunnel/src/main/java/com/zaneschepke/tunnel/service/TunnelService.kt +++ b/tunnel/src/main/java/com/zaneschepke/tunnel/service/TunnelService.kt @@ -55,7 +55,7 @@ class TunnelService : LifecycleService() { @OptIn(ExperimentalAtomicApi::class) override fun onDestroy() { ServiceCompat.stopForeground(this, ServiceCompat.STOP_FOREGROUND_REMOVE) - serviceHolder.signalTunnelServiceDestroyed() + serviceHolder.clearTunnelService() if (!userActivatedShutdown) { Timber.d("Service being killed by system, clean up tunnels") shutdownScope.launch { diff --git a/tunnel/src/main/java/com/zaneschepke/tunnel/service/VpnService.kt b/tunnel/src/main/java/com/zaneschepke/tunnel/service/VpnService.kt index a696148e..961580b8 100644 --- a/tunnel/src/main/java/com/zaneschepke/tunnel/service/VpnService.kt +++ b/tunnel/src/main/java/com/zaneschepke/tunnel/service/VpnService.kt @@ -8,7 +8,6 @@ import android.system.OsConstants import androidx.core.app.ServiceCompat import com.zaneschepke.hevtunnel.HevTunnelConfig import com.zaneschepke.hevtunnel.TProxyService -import com.zaneschepke.tunnel.ProxyBackend import com.zaneschepke.tunnel.Tunnel import com.zaneschepke.tunnel.backend.Backend import com.zaneschepke.tunnel.backend.KillSwitch @@ -50,7 +49,6 @@ class VpnService : android.net.VpnService(), KillSwitch, SocketProtector { override fun onCreate() { serviceHolder.set(this) - ProxyBackend.setSocketProtector(this) launchForegroundNotification() super.onCreate() } @@ -68,7 +66,7 @@ class VpnService : android.net.VpnService(), KillSwitch, SocketProtector { override fun onDestroy() { Timber.d("VpnService destroyed") try { - ProxyBackend.setSocketProtector(null) + serviceHolder.clearVpnService() ServiceCompat.stopForeground(this, ServiceCompat.STOP_FOREGROUND_REMOVE) disableKillSwitch() hevBridgeJob?.cancel() @@ -83,7 +81,6 @@ class VpnService : android.net.VpnService(), KillSwitch, SocketProtector { } } } finally { - serviceHolder.signalVpnServiceDestroyed() super.onDestroy() } } diff --git a/tunnel/tools/libwg-go/proxy/proxy.go b/tunnel/tools/libwg-go/proxy/proxy.go index 35714361..d4c6763f 100644 --- a/tunnel/tools/libwg-go/proxy/proxy.go +++ b/tunnel/tools/libwg-go/proxy/proxy.go @@ -8,7 +8,6 @@ import ( "context" "net" "sync" - "syscall" "github.com/amnezia-vpn/amneziawg-go/conn" "github.com/amnezia-vpn/amneziawg-go/device" @@ -18,8 +17,6 @@ import ( "github.com/wgtunnel/android/shared" ) -import "C" - var ( cancelFuncs map[int32]context.CancelFunc tag string @@ -76,7 +73,7 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass var bind conn.Bind if bypass == 1 { - bind = conn.NewStdNetBindWithControl(protectControlFunc) + bind = conn.NewStdNetBindWithControl(shared.ProtectControlFunc) } else { bind = conn.NewStdNetBind() } @@ -228,23 +225,6 @@ func awgGetProxyConfig(tunnelHandle int32) *C.char { return C.CString(settings) } -// control hook to bypass sockets -func protectControlFunc(network, address string, c syscall.RawConn) error { - var opErr error - err := c.Control(func(fd uintptr) { - if C.bypass_socket(C.int(fd)) == 0 { - opErr = syscall.EACCES - shared.LogError(tag, "Failed to protect socket FD: %d", fd) - } else { - shared.LogDebug(tag, "Protected socket FD: %d", fd) - } - }) - if err != nil { - return err - } - return opErr -} - //export awgTurnProxyTunnelOff func awgTurnProxyTunnelOff(virtualTunnelHandle int32) { diff --git a/tunnel/tools/libwg-go/shared/protect.go b/tunnel/tools/libwg-go/shared/protect.go new file mode 100644 index 00000000..59328beb --- /dev/null +++ b/tunnel/tools/libwg-go/shared/protect.go @@ -0,0 +1,24 @@ +package shared + +/* +#include "vpn_jni.h" +*/ +import "C" +import "syscall" + +// ProtectControlFunc control hook to bypass sockets +func ProtectControlFunc(network, address string, c syscall.RawConn) error { + var opErr error + err := c.Control(func(fd uintptr) { + if C.bypass_socket(C.int(fd)) == 0 { + opErr = syscall.EACCES + LogError("Protect", "Failed to protect socket FD: %d", fd) + } else { + LogDebug("Protect", "Protected socket FD: %d", fd) + } + }) + if err != nil { + return err + } + return opErr +} diff --git a/tunnel/tools/libwg-go/vpn/vpn.go b/tunnel/tools/libwg-go/vpn/vpn.go index ff11173a..14538b57 100644 --- a/tunnel/tools/libwg-go/vpn/vpn.go +++ b/tunnel/tools/libwg-go/vpn/vpn.go @@ -78,7 +78,7 @@ func awgTurnOn(interfaceName string, tunFd int32, settings string, uapiPath stri go C.awgNotifyStatus(C.int32_t(handle), C.int32_t(code)) } - tunDevice := device.NewDevice(tunnel, conn.NewStdNetBind(), shared.NewLogger("Tun/"+interfaceName), statusCB) + tunDevice := device.NewDevice(tunnel, conn.NewStdNetBindWithControl(shared.ProtectControlFunc), shared.NewLogger("Tun/"+interfaceName), statusCB) tunDevice.DisableSomeRoamingForBrokenMobileSemantics() ipcRequest, err := wireproxyawg.CreateIPCRequest(conf.Device, false) @@ -213,54 +213,6 @@ func awgTurnOff(tunnelHandle int32) { ) } -//export awgGetSocketV4 -func awgGetSocketV4(tunnelHandle int32) int32 { - - tunnelMu.RLock() - handle, ok := tunnelHandles[tunnelHandle] - tunnelMu.RUnlock() - - if !ok { - return -1 - } - - bind, _ := handle.device.Bind().(conn.PeekLookAtSocketFd) - if bind == nil { - return -1 - } - - fd, err := bind.PeekLookAtSocketFd4() - if err != nil { - return -1 - } - - return int32(fd) -} - -//export awgGetSocketV6 -func awgGetSocketV6(tunnelHandle int32) int32 { - - tunnelMu.RLock() - handle, ok := tunnelHandles[tunnelHandle] - tunnelMu.RUnlock() - - if !ok { - return -1 - } - - bind, _ := handle.device.Bind().(conn.PeekLookAtSocketFd) - if bind == nil { - return -1 - } - - fd, err := bind.PeekLookAtSocketFd6() - if err != nil { - return -1 - } - - return int32(fd) -} - //export awgGetConfig func awgGetConfig(tunnelHandle int32) *C.char { diff --git a/tunnel/tools/libwg-go/vpn/vpn_jni.c b/tunnel/tools/libwg-go/vpn/vpn_jni.c index 2a51b2e9..333a651d 100644 --- a/tunnel/tools/libwg-go/vpn/vpn_jni.c +++ b/tunnel/tools/libwg-go/vpn/vpn_jni.c @@ -11,8 +11,6 @@ struct go_string { const char *str; long n; }; extern int awgTurnOn(struct go_string ifname, int tun_fd, struct go_string settings, struct go_string uapipath); extern void awgTurnOff(int handle); -extern int awgGetSocketV4(int handle); -extern int awgGetSocketV6(int handle); extern char *awgGetConfig(int handle); extern char *awgVersion(); extern int awgUpdateTunnelPeers(int handle, struct go_string settings); @@ -46,16 +44,6 @@ JNIEXPORT void JNICALL Java_com_zaneschepke_tunnel_VpnBackend_awgTurnOff(JNIEnv awgTurnOff(handle); } -JNIEXPORT jint JNICALL Java_com_zaneschepke_tunnel_VpnBackend_awgGetSocketV4(JNIEnv *env, jclass c, jint handle) -{ -return awgGetSocketV4(handle); -} - -JNIEXPORT jint JNICALL Java_com_zaneschepke_tunnel_VpnBackend_awgGetSocketV6(JNIEnv *env, jclass c, jint handle) -{ -return awgGetSocketV6(handle); -} - JNIEXPORT jstring JNICALL Java_com_zaneschepke_tunnel_VpnBackend_awgGetConfig(JNIEnv *env, jclass c, jint handle) { jstring ret;