fix: tunnel sockets protection race

This race was especially impacting GrapheneOS devices

#1274
This commit is contained in:
zaneschepke
2026-06-17 02:32:55 -04:00
parent 379ffdcbbf
commit 9ee1fa69ed
9 changed files with 84 additions and 149 deletions
@@ -15,10 +15,6 @@ internal object VpnBackend {
external fun awgGetConfig(handle: Int): String?
external fun awgGetSocketV4(handle: Int): Int
external fun awgGetSocketV6(handle: Int): Int
external fun awgTurnOff(handle: Int)
external fun awgTurnOn(ifName: String, tunFd: Int, settings: String, uapiPath: String): Int
@@ -2,14 +2,18 @@ package com.zaneschepke.tunnel.backend
import android.content.Context
import android.content.Intent
import com.zaneschepke.tunnel.ProxyBackend
import com.zaneschepke.tunnel.service.TunnelService
import com.zaneschepke.tunnel.service.VpnService
import com.zaneschepke.tunnel.util.BackendException
import java.lang.ref.WeakReference
import kotlin.time.Duration.Companion.milliseconds
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.withTimeout
import kotlinx.coroutines.withTimeoutOrNull
import timber.log.Timber
@@ -18,89 +22,71 @@ internal class ServiceHolder(val context: Context) {
internal val uapiPath = context.dataDir.absolutePath
@Volatile private var vpnService = CompletableDeferred<VpnService>()
@Volatile private var tunnelService = CompletableDeferred<TunnelService>()
@Volatile private var vpnServiceDestroyed = CompletableDeferred<Unit>()
@Volatile private var tunnelServiceDestroyed = CompletableDeferred<Unit>()
private val _vpnService = MutableStateFlow<VpnService?>(null)
val vpnServiceFlow: StateFlow<VpnService?> = _vpnService.asStateFlow()
private val _tunnelService = MutableStateFlow<TunnelService?>(null)
val tunnelServiceFlow: StateFlow<TunnelService?> = _tunnelService.asStateFlow()
fun set(service: VpnService) {
vpnService.complete(service)
_vpnService.value = service
ProxyBackend.setSocketProtector(service)
}
fun set(service: TunnelService) {
tunnelService.complete(service)
_tunnelService.value = service
}
fun signalVpnServiceDestroyed() {
vpnServiceDestroyed.complete(Unit)
fun clearVpnService() {
ProxyBackend.setSocketProtector(null)
_vpnService.value = null
}
fun signalTunnelServiceDestroyed() {
tunnelServiceDestroyed.complete(Unit)
fun clearTunnelService() {
_tunnelService.value = null
}
@OptIn(ExperimentalCoroutinesApi::class)
suspend fun getVpnService(): VpnService {
if (vpnService.isCompleted && !vpnService.isCancelled) {
return vpnService.getCompleted()
}
if (android.net.VpnService.prepare(context) != null) {
throw BackendException.Unauthorized("Permission unavailable to use VpnService")
}
context.startForegroundService(Intent(context, VpnService::class.java))
if (_vpnService.value == null) {
context.startForegroundService(Intent(context, VpnService::class.java))
}
return try {
withTimeout(3_000L.milliseconds) { vpnService.await() }
withTimeout(3_000L.milliseconds) { vpnServiceFlow.filterNotNull().first() }
} catch (e: TimeoutCancellationException) {
Timber.e(e, "Timed out getting VpnService")
throw BackendException.InternalError("Failed to get VpnService")
}
}
@OptIn(ExperimentalCoroutinesApi::class)
suspend fun getTunnelService(): TunnelService {
if (tunnelService.isCompleted && !tunnelService.isCancelled) {
return tunnelService.getCompleted()
if (_tunnelService.value == null) {
context.startForegroundService(Intent(context, TunnelService::class.java))
}
context.startForegroundService(Intent(context, TunnelService::class.java))
return try {
withTimeout(3_000L.milliseconds) { tunnelService.await() }
withTimeout(3_000L.milliseconds) { tunnelServiceFlow.filterNotNull().first() }
} catch (e: TimeoutCancellationException) {
Timber.e(e, "Timed out getting TunnelService")
throw BackendException.InternalError("Failed to get TunnelService")
}
}
@OptIn(ExperimentalCoroutinesApi::class)
suspend fun stopTunnelService() {
val service =
if (tunnelService.isCompleted && !tunnelService.isCancelled) {
tunnelService.getCompleted()
} else return
tunnelServiceDestroyed = CompletableDeferred()
suspend fun stopVpnService() {
val service = _vpnService.value ?: return
clearVpnService()
service.shutdown()
tunnelService = CompletableDeferred()
withTimeoutOrNull(1_000L.milliseconds) { tunnelServiceDestroyed.await() }
withTimeoutOrNull(1_000L.milliseconds) { vpnServiceFlow.first { it == null } }
}
@OptIn(ExperimentalCoroutinesApi::class)
suspend fun stopVpnService() {
val service =
if (vpnService.isCompleted && !vpnService.isCancelled) {
vpnService.getCompleted()
} else return
vpnServiceDestroyed = CompletableDeferred()
suspend fun stopTunnelService() {
val service = _tunnelService.value ?: return
clearTunnelService()
service.shutdown()
vpnService = CompletableDeferred()
withTimeoutOrNull(1_000L.milliseconds) { vpnServiceDestroyed.await() }
withTimeoutOrNull(1_000L.milliseconds) { tunnelServiceFlow.first { it == null } }
}
companion object {
@@ -176,8 +176,7 @@ internal class WireGuardTunnelEngine(private val serviceHolder: ServiceHolder) :
}
private suspend fun startVpnTunnel(tunnel: Tunnel, ifName: String, config: Config): Int {
val service = serviceHolder.getVpnService()
val service = ensureVpnProtectorRegistered()
val fd =
service.createTunInterface(tunnel, config)?.detachFd()
@@ -185,14 +184,9 @@ internal class WireGuardTunnelEngine(private val serviceHolder: ServiceHolder) :
val handle =
VpnBackend.awgTurnOn(ifName, fd, config.asQuickString(), serviceHolder.uapiPath)
if (handle < 0) {
throw BackendException.InternalError("Internal native error with code: $handle")
}
service.protect(VpnBackend.awgGetSocketV4(handle))
service.protect(VpnBackend.awgGetSocketV6(handle))
return handle
}
@@ -224,13 +218,20 @@ internal class WireGuardTunnelEngine(private val serviceHolder: ServiceHolder) :
proxyConfig: ProxyConfig,
withBridge: Boolean,
): Int {
val quickConfig = buildProxiedQuickString(config, proxyConfig)
if (!withBridge) {
serviceHolder.getTunnelService()
}
// Get VpnService and ensure protector is registered
val vpnService =
if (withBridge) {
ensureVpnProtectorRegistered()
} else {
null
}
val handle =
ProxyBackend.awgStartProxy(
ifName,
@@ -238,12 +239,12 @@ internal class WireGuardTunnelEngine(private val serviceHolder: ServiceHolder) :
serviceHolder.uapiPath,
if (withBridge) 1 else 0,
)
if (handle < 0) {
throw BackendException.InternalError("Internal native error")
}
if (withBridge) {
// Start HEV bridge after the proxy tunnel is up
if (withBridge && vpnService != null) {
val port =
proxyConfig.socks5?.port
?: throw BackendException.InternalError(
@@ -254,12 +255,23 @@ internal class WireGuardTunnelEngine(private val serviceHolder: ServiceHolder) :
?: throw BackendException.InternalError(
"Bridge pass not set for kill switch proxy config"
)
serviceHolder.getVpnService().startHevSocks5Bridge(port, pass)
vpnService.startHevSocks5Bridge(port, pass)
}
return handle
}
/**
* Gets the VpnService and starts if needed while ensuring the protector is registered. This is
* needed before any native call that uses NewStdNetBindWithControl.
*/
private suspend fun ensureVpnProtectorRegistered(): VpnService {
val service = serviceHolder.getVpnService()
ProxyBackend.setSocketProtector(service)
return service
}
private fun buildProxiedQuickString(config: Config, proxyConfig: ProxyConfig): String {
return buildString {
append(config.asQuickString())
@@ -55,7 +55,7 @@ class TunnelService : LifecycleService() {
@OptIn(ExperimentalAtomicApi::class)
override fun onDestroy() {
ServiceCompat.stopForeground(this, ServiceCompat.STOP_FOREGROUND_REMOVE)
serviceHolder.signalTunnelServiceDestroyed()
serviceHolder.clearTunnelService()
if (!userActivatedShutdown) {
Timber.d("Service being killed by system, clean up tunnels")
shutdownScope.launch {
@@ -8,7 +8,6 @@ import android.system.OsConstants
import androidx.core.app.ServiceCompat
import com.zaneschepke.hevtunnel.HevTunnelConfig
import com.zaneschepke.hevtunnel.TProxyService
import com.zaneschepke.tunnel.ProxyBackend
import com.zaneschepke.tunnel.Tunnel
import com.zaneschepke.tunnel.backend.Backend
import com.zaneschepke.tunnel.backend.KillSwitch
@@ -50,7 +49,6 @@ class VpnService : android.net.VpnService(), KillSwitch, SocketProtector {
override fun onCreate() {
serviceHolder.set(this)
ProxyBackend.setSocketProtector(this)
launchForegroundNotification()
super.onCreate()
}
@@ -68,7 +66,7 @@ class VpnService : android.net.VpnService(), KillSwitch, SocketProtector {
override fun onDestroy() {
Timber.d("VpnService destroyed")
try {
ProxyBackend.setSocketProtector(null)
serviceHolder.clearVpnService()
ServiceCompat.stopForeground(this, ServiceCompat.STOP_FOREGROUND_REMOVE)
disableKillSwitch()
hevBridgeJob?.cancel()
@@ -83,7 +81,6 @@ class VpnService : android.net.VpnService(), KillSwitch, SocketProtector {
}
}
} finally {
serviceHolder.signalVpnServiceDestroyed()
super.onDestroy()
}
}
+1 -21
View File
@@ -8,7 +8,6 @@ import (
"context"
"net"
"sync"
"syscall"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device"
@@ -18,8 +17,6 @@ import (
"github.com/wgtunnel/android/shared"
)
import "C"
var (
cancelFuncs map[int32]context.CancelFunc
tag string
@@ -76,7 +73,7 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass
var bind conn.Bind
if bypass == 1 {
bind = conn.NewStdNetBindWithControl(protectControlFunc)
bind = conn.NewStdNetBindWithControl(shared.ProtectControlFunc)
} else {
bind = conn.NewStdNetBind()
}
@@ -228,23 +225,6 @@ func awgGetProxyConfig(tunnelHandle int32) *C.char {
return C.CString(settings)
}
// control hook to bypass sockets
func protectControlFunc(network, address string, c syscall.RawConn) error {
var opErr error
err := c.Control(func(fd uintptr) {
if C.bypass_socket(C.int(fd)) == 0 {
opErr = syscall.EACCES
shared.LogError(tag, "Failed to protect socket FD: %d", fd)
} else {
shared.LogDebug(tag, "Protected socket FD: %d", fd)
}
})
if err != nil {
return err
}
return opErr
}
//export awgTurnProxyTunnelOff
func awgTurnProxyTunnelOff(virtualTunnelHandle int32) {
+24
View File
@@ -0,0 +1,24 @@
package shared
/*
#include "vpn_jni.h"
*/
import "C"
import "syscall"
// ProtectControlFunc control hook to bypass sockets
func ProtectControlFunc(network, address string, c syscall.RawConn) error {
var opErr error
err := c.Control(func(fd uintptr) {
if C.bypass_socket(C.int(fd)) == 0 {
opErr = syscall.EACCES
LogError("Protect", "Failed to protect socket FD: %d", fd)
} else {
LogDebug("Protect", "Protected socket FD: %d", fd)
}
})
if err != nil {
return err
}
return opErr
}
+1 -49
View File
@@ -78,7 +78,7 @@ func awgTurnOn(interfaceName string, tunFd int32, settings string, uapiPath stri
go C.awgNotifyStatus(C.int32_t(handle), C.int32_t(code))
}
tunDevice := device.NewDevice(tunnel, conn.NewStdNetBind(), shared.NewLogger("Tun/"+interfaceName), statusCB)
tunDevice := device.NewDevice(tunnel, conn.NewStdNetBindWithControl(shared.ProtectControlFunc), shared.NewLogger("Tun/"+interfaceName), statusCB)
tunDevice.DisableSomeRoamingForBrokenMobileSemantics()
ipcRequest, err := wireproxyawg.CreateIPCRequest(conf.Device, false)
@@ -213,54 +213,6 @@ func awgTurnOff(tunnelHandle int32) {
)
}
//export awgGetSocketV4
func awgGetSocketV4(tunnelHandle int32) int32 {
tunnelMu.RLock()
handle, ok := tunnelHandles[tunnelHandle]
tunnelMu.RUnlock()
if !ok {
return -1
}
bind, _ := handle.device.Bind().(conn.PeekLookAtSocketFd)
if bind == nil {
return -1
}
fd, err := bind.PeekLookAtSocketFd4()
if err != nil {
return -1
}
return int32(fd)
}
//export awgGetSocketV6
func awgGetSocketV6(tunnelHandle int32) int32 {
tunnelMu.RLock()
handle, ok := tunnelHandles[tunnelHandle]
tunnelMu.RUnlock()
if !ok {
return -1
}
bind, _ := handle.device.Bind().(conn.PeekLookAtSocketFd)
if bind == nil {
return -1
}
fd, err := bind.PeekLookAtSocketFd6()
if err != nil {
return -1
}
return int32(fd)
}
//export awgGetConfig
func awgGetConfig(tunnelHandle int32) *C.char {
-12
View File
@@ -11,8 +11,6 @@
struct go_string { const char *str; long n; };
extern int awgTurnOn(struct go_string ifname, int tun_fd, struct go_string settings, struct go_string uapipath);
extern void awgTurnOff(int handle);
extern int awgGetSocketV4(int handle);
extern int awgGetSocketV6(int handle);
extern char *awgGetConfig(int handle);
extern char *awgVersion();
extern int awgUpdateTunnelPeers(int handle, struct go_string settings);
@@ -46,16 +44,6 @@ JNIEXPORT void JNICALL Java_com_zaneschepke_tunnel_VpnBackend_awgTurnOff(JNIEnv
awgTurnOff(handle);
}
JNIEXPORT jint JNICALL Java_com_zaneschepke_tunnel_VpnBackend_awgGetSocketV4(JNIEnv *env, jclass c, jint handle)
{
return awgGetSocketV4(handle);
}
JNIEXPORT jint JNICALL Java_com_zaneschepke_tunnel_VpnBackend_awgGetSocketV6(JNIEnv *env, jclass c, jint handle)
{
return awgGetSocketV6(handle);
}
JNIEXPORT jstring JNICALL Java_com_zaneschepke_tunnel_VpnBackend_awgGetConfig(JNIEnv *env, jclass c, jint handle)
{
jstring ret;