mirror of
https://github.com/wgtunnel/android.git
synced 2026-07-03 14:07:49 +02:00
fix: tunnel sockets protection race
This race was especially impacting GrapheneOS devices #1274
This commit is contained in:
@@ -15,10 +15,6 @@ internal object VpnBackend {
|
||||
|
||||
external fun awgGetConfig(handle: Int): String?
|
||||
|
||||
external fun awgGetSocketV4(handle: Int): Int
|
||||
|
||||
external fun awgGetSocketV6(handle: Int): Int
|
||||
|
||||
external fun awgTurnOff(handle: Int)
|
||||
|
||||
external fun awgTurnOn(ifName: String, tunFd: Int, settings: String, uapiPath: String): Int
|
||||
|
||||
@@ -2,14 +2,18 @@ package com.zaneschepke.tunnel.backend
|
||||
|
||||
import android.content.Context
|
||||
import android.content.Intent
|
||||
import com.zaneschepke.tunnel.ProxyBackend
|
||||
import com.zaneschepke.tunnel.service.TunnelService
|
||||
import com.zaneschepke.tunnel.service.VpnService
|
||||
import com.zaneschepke.tunnel.util.BackendException
|
||||
import java.lang.ref.WeakReference
|
||||
import kotlin.time.Duration.Companion.milliseconds
|
||||
import kotlinx.coroutines.CompletableDeferred
|
||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
||||
import kotlinx.coroutines.TimeoutCancellationException
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
import kotlinx.coroutines.flow.filterNotNull
|
||||
import kotlinx.coroutines.flow.first
|
||||
import kotlinx.coroutines.withTimeout
|
||||
import kotlinx.coroutines.withTimeoutOrNull
|
||||
import timber.log.Timber
|
||||
@@ -18,89 +22,71 @@ internal class ServiceHolder(val context: Context) {
|
||||
|
||||
internal val uapiPath = context.dataDir.absolutePath
|
||||
|
||||
@Volatile private var vpnService = CompletableDeferred<VpnService>()
|
||||
@Volatile private var tunnelService = CompletableDeferred<TunnelService>()
|
||||
@Volatile private var vpnServiceDestroyed = CompletableDeferred<Unit>()
|
||||
@Volatile private var tunnelServiceDestroyed = CompletableDeferred<Unit>()
|
||||
private val _vpnService = MutableStateFlow<VpnService?>(null)
|
||||
val vpnServiceFlow: StateFlow<VpnService?> = _vpnService.asStateFlow()
|
||||
private val _tunnelService = MutableStateFlow<TunnelService?>(null)
|
||||
val tunnelServiceFlow: StateFlow<TunnelService?> = _tunnelService.asStateFlow()
|
||||
|
||||
fun set(service: VpnService) {
|
||||
vpnService.complete(service)
|
||||
_vpnService.value = service
|
||||
ProxyBackend.setSocketProtector(service)
|
||||
}
|
||||
|
||||
fun set(service: TunnelService) {
|
||||
tunnelService.complete(service)
|
||||
_tunnelService.value = service
|
||||
}
|
||||
|
||||
fun signalVpnServiceDestroyed() {
|
||||
vpnServiceDestroyed.complete(Unit)
|
||||
fun clearVpnService() {
|
||||
ProxyBackend.setSocketProtector(null)
|
||||
_vpnService.value = null
|
||||
}
|
||||
|
||||
fun signalTunnelServiceDestroyed() {
|
||||
tunnelServiceDestroyed.complete(Unit)
|
||||
fun clearTunnelService() {
|
||||
_tunnelService.value = null
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalCoroutinesApi::class)
|
||||
suspend fun getVpnService(): VpnService {
|
||||
if (vpnService.isCompleted && !vpnService.isCancelled) {
|
||||
return vpnService.getCompleted()
|
||||
}
|
||||
|
||||
if (android.net.VpnService.prepare(context) != null) {
|
||||
throw BackendException.Unauthorized("Permission unavailable to use VpnService")
|
||||
}
|
||||
|
||||
context.startForegroundService(Intent(context, VpnService::class.java))
|
||||
if (_vpnService.value == null) {
|
||||
context.startForegroundService(Intent(context, VpnService::class.java))
|
||||
}
|
||||
|
||||
return try {
|
||||
withTimeout(3_000L.milliseconds) { vpnService.await() }
|
||||
withTimeout(3_000L.milliseconds) { vpnServiceFlow.filterNotNull().first() }
|
||||
} catch (e: TimeoutCancellationException) {
|
||||
Timber.e(e, "Timed out getting VpnService")
|
||||
throw BackendException.InternalError("Failed to get VpnService")
|
||||
}
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalCoroutinesApi::class)
|
||||
suspend fun getTunnelService(): TunnelService {
|
||||
if (tunnelService.isCompleted && !tunnelService.isCancelled) {
|
||||
return tunnelService.getCompleted()
|
||||
if (_tunnelService.value == null) {
|
||||
context.startForegroundService(Intent(context, TunnelService::class.java))
|
||||
}
|
||||
|
||||
context.startForegroundService(Intent(context, TunnelService::class.java))
|
||||
|
||||
return try {
|
||||
withTimeout(3_000L.milliseconds) { tunnelService.await() }
|
||||
withTimeout(3_000L.milliseconds) { tunnelServiceFlow.filterNotNull().first() }
|
||||
} catch (e: TimeoutCancellationException) {
|
||||
Timber.e(e, "Timed out getting TunnelService")
|
||||
throw BackendException.InternalError("Failed to get TunnelService")
|
||||
}
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalCoroutinesApi::class)
|
||||
suspend fun stopTunnelService() {
|
||||
val service =
|
||||
if (tunnelService.isCompleted && !tunnelService.isCancelled) {
|
||||
tunnelService.getCompleted()
|
||||
} else return
|
||||
|
||||
tunnelServiceDestroyed = CompletableDeferred()
|
||||
|
||||
suspend fun stopVpnService() {
|
||||
val service = _vpnService.value ?: return
|
||||
clearVpnService()
|
||||
service.shutdown()
|
||||
tunnelService = CompletableDeferred()
|
||||
withTimeoutOrNull(1_000L.milliseconds) { tunnelServiceDestroyed.await() }
|
||||
withTimeoutOrNull(1_000L.milliseconds) { vpnServiceFlow.first { it == null } }
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalCoroutinesApi::class)
|
||||
suspend fun stopVpnService() {
|
||||
val service =
|
||||
if (vpnService.isCompleted && !vpnService.isCancelled) {
|
||||
vpnService.getCompleted()
|
||||
} else return
|
||||
|
||||
vpnServiceDestroyed = CompletableDeferred()
|
||||
|
||||
suspend fun stopTunnelService() {
|
||||
val service = _tunnelService.value ?: return
|
||||
clearTunnelService()
|
||||
service.shutdown()
|
||||
vpnService = CompletableDeferred()
|
||||
withTimeoutOrNull(1_000L.milliseconds) { vpnServiceDestroyed.await() }
|
||||
withTimeoutOrNull(1_000L.milliseconds) { tunnelServiceFlow.first { it == null } }
|
||||
}
|
||||
|
||||
companion object {
|
||||
|
||||
@@ -176,8 +176,7 @@ internal class WireGuardTunnelEngine(private val serviceHolder: ServiceHolder) :
|
||||
}
|
||||
|
||||
private suspend fun startVpnTunnel(tunnel: Tunnel, ifName: String, config: Config): Int {
|
||||
|
||||
val service = serviceHolder.getVpnService()
|
||||
val service = ensureVpnProtectorRegistered()
|
||||
|
||||
val fd =
|
||||
service.createTunInterface(tunnel, config)?.detachFd()
|
||||
@@ -185,14 +184,9 @@ internal class WireGuardTunnelEngine(private val serviceHolder: ServiceHolder) :
|
||||
|
||||
val handle =
|
||||
VpnBackend.awgTurnOn(ifName, fd, config.asQuickString(), serviceHolder.uapiPath)
|
||||
|
||||
if (handle < 0) {
|
||||
throw BackendException.InternalError("Internal native error with code: $handle")
|
||||
}
|
||||
|
||||
service.protect(VpnBackend.awgGetSocketV4(handle))
|
||||
service.protect(VpnBackend.awgGetSocketV6(handle))
|
||||
|
||||
return handle
|
||||
}
|
||||
|
||||
@@ -224,13 +218,20 @@ internal class WireGuardTunnelEngine(private val serviceHolder: ServiceHolder) :
|
||||
proxyConfig: ProxyConfig,
|
||||
withBridge: Boolean,
|
||||
): Int {
|
||||
|
||||
val quickConfig = buildProxiedQuickString(config, proxyConfig)
|
||||
|
||||
if (!withBridge) {
|
||||
serviceHolder.getTunnelService()
|
||||
}
|
||||
|
||||
// Get VpnService and ensure protector is registered
|
||||
val vpnService =
|
||||
if (withBridge) {
|
||||
ensureVpnProtectorRegistered()
|
||||
} else {
|
||||
null
|
||||
}
|
||||
|
||||
val handle =
|
||||
ProxyBackend.awgStartProxy(
|
||||
ifName,
|
||||
@@ -238,12 +239,12 @@ internal class WireGuardTunnelEngine(private val serviceHolder: ServiceHolder) :
|
||||
serviceHolder.uapiPath,
|
||||
if (withBridge) 1 else 0,
|
||||
)
|
||||
|
||||
if (handle < 0) {
|
||||
throw BackendException.InternalError("Internal native error")
|
||||
}
|
||||
|
||||
if (withBridge) {
|
||||
// Start HEV bridge after the proxy tunnel is up
|
||||
if (withBridge && vpnService != null) {
|
||||
val port =
|
||||
proxyConfig.socks5?.port
|
||||
?: throw BackendException.InternalError(
|
||||
@@ -254,12 +255,23 @@ internal class WireGuardTunnelEngine(private val serviceHolder: ServiceHolder) :
|
||||
?: throw BackendException.InternalError(
|
||||
"Bridge pass not set for kill switch proxy config"
|
||||
)
|
||||
serviceHolder.getVpnService().startHevSocks5Bridge(port, pass)
|
||||
|
||||
vpnService.startHevSocks5Bridge(port, pass)
|
||||
}
|
||||
|
||||
return handle
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the VpnService and starts if needed while ensuring the protector is registered. This is
|
||||
* needed before any native call that uses NewStdNetBindWithControl.
|
||||
*/
|
||||
private suspend fun ensureVpnProtectorRegistered(): VpnService {
|
||||
val service = serviceHolder.getVpnService()
|
||||
ProxyBackend.setSocketProtector(service)
|
||||
return service
|
||||
}
|
||||
|
||||
private fun buildProxiedQuickString(config: Config, proxyConfig: ProxyConfig): String {
|
||||
return buildString {
|
||||
append(config.asQuickString())
|
||||
|
||||
@@ -55,7 +55,7 @@ class TunnelService : LifecycleService() {
|
||||
@OptIn(ExperimentalAtomicApi::class)
|
||||
override fun onDestroy() {
|
||||
ServiceCompat.stopForeground(this, ServiceCompat.STOP_FOREGROUND_REMOVE)
|
||||
serviceHolder.signalTunnelServiceDestroyed()
|
||||
serviceHolder.clearTunnelService()
|
||||
if (!userActivatedShutdown) {
|
||||
Timber.d("Service being killed by system, clean up tunnels")
|
||||
shutdownScope.launch {
|
||||
|
||||
@@ -8,7 +8,6 @@ import android.system.OsConstants
|
||||
import androidx.core.app.ServiceCompat
|
||||
import com.zaneschepke.hevtunnel.HevTunnelConfig
|
||||
import com.zaneschepke.hevtunnel.TProxyService
|
||||
import com.zaneschepke.tunnel.ProxyBackend
|
||||
import com.zaneschepke.tunnel.Tunnel
|
||||
import com.zaneschepke.tunnel.backend.Backend
|
||||
import com.zaneschepke.tunnel.backend.KillSwitch
|
||||
@@ -50,7 +49,6 @@ class VpnService : android.net.VpnService(), KillSwitch, SocketProtector {
|
||||
|
||||
override fun onCreate() {
|
||||
serviceHolder.set(this)
|
||||
ProxyBackend.setSocketProtector(this)
|
||||
launchForegroundNotification()
|
||||
super.onCreate()
|
||||
}
|
||||
@@ -68,7 +66,7 @@ class VpnService : android.net.VpnService(), KillSwitch, SocketProtector {
|
||||
override fun onDestroy() {
|
||||
Timber.d("VpnService destroyed")
|
||||
try {
|
||||
ProxyBackend.setSocketProtector(null)
|
||||
serviceHolder.clearVpnService()
|
||||
ServiceCompat.stopForeground(this, ServiceCompat.STOP_FOREGROUND_REMOVE)
|
||||
disableKillSwitch()
|
||||
hevBridgeJob?.cancel()
|
||||
@@ -83,7 +81,6 @@ class VpnService : android.net.VpnService(), KillSwitch, SocketProtector {
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
serviceHolder.signalVpnServiceDestroyed()
|
||||
super.onDestroy()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/device"
|
||||
@@ -18,8 +17,6 @@ import (
|
||||
"github.com/wgtunnel/android/shared"
|
||||
)
|
||||
|
||||
import "C"
|
||||
|
||||
var (
|
||||
cancelFuncs map[int32]context.CancelFunc
|
||||
tag string
|
||||
@@ -76,7 +73,7 @@ func awgStartProxy(interfaceName string, config string, uapiPath string, bypass
|
||||
|
||||
var bind conn.Bind
|
||||
if bypass == 1 {
|
||||
bind = conn.NewStdNetBindWithControl(protectControlFunc)
|
||||
bind = conn.NewStdNetBindWithControl(shared.ProtectControlFunc)
|
||||
} else {
|
||||
bind = conn.NewStdNetBind()
|
||||
}
|
||||
@@ -228,23 +225,6 @@ func awgGetProxyConfig(tunnelHandle int32) *C.char {
|
||||
return C.CString(settings)
|
||||
}
|
||||
|
||||
// control hook to bypass sockets
|
||||
func protectControlFunc(network, address string, c syscall.RawConn) error {
|
||||
var opErr error
|
||||
err := c.Control(func(fd uintptr) {
|
||||
if C.bypass_socket(C.int(fd)) == 0 {
|
||||
opErr = syscall.EACCES
|
||||
shared.LogError(tag, "Failed to protect socket FD: %d", fd)
|
||||
} else {
|
||||
shared.LogDebug(tag, "Protected socket FD: %d", fd)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return opErr
|
||||
}
|
||||
|
||||
//export awgTurnProxyTunnelOff
|
||||
func awgTurnProxyTunnelOff(virtualTunnelHandle int32) {
|
||||
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
package shared
|
||||
|
||||
/*
|
||||
#include "vpn_jni.h"
|
||||
*/
|
||||
import "C"
|
||||
import "syscall"
|
||||
|
||||
// ProtectControlFunc control hook to bypass sockets
|
||||
func ProtectControlFunc(network, address string, c syscall.RawConn) error {
|
||||
var opErr error
|
||||
err := c.Control(func(fd uintptr) {
|
||||
if C.bypass_socket(C.int(fd)) == 0 {
|
||||
opErr = syscall.EACCES
|
||||
LogError("Protect", "Failed to protect socket FD: %d", fd)
|
||||
} else {
|
||||
LogDebug("Protect", "Protected socket FD: %d", fd)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return opErr
|
||||
}
|
||||
@@ -78,7 +78,7 @@ func awgTurnOn(interfaceName string, tunFd int32, settings string, uapiPath stri
|
||||
go C.awgNotifyStatus(C.int32_t(handle), C.int32_t(code))
|
||||
}
|
||||
|
||||
tunDevice := device.NewDevice(tunnel, conn.NewStdNetBind(), shared.NewLogger("Tun/"+interfaceName), statusCB)
|
||||
tunDevice := device.NewDevice(tunnel, conn.NewStdNetBindWithControl(shared.ProtectControlFunc), shared.NewLogger("Tun/"+interfaceName), statusCB)
|
||||
tunDevice.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
|
||||
ipcRequest, err := wireproxyawg.CreateIPCRequest(conf.Device, false)
|
||||
@@ -213,54 +213,6 @@ func awgTurnOff(tunnelHandle int32) {
|
||||
)
|
||||
}
|
||||
|
||||
//export awgGetSocketV4
|
||||
func awgGetSocketV4(tunnelHandle int32) int32 {
|
||||
|
||||
tunnelMu.RLock()
|
||||
handle, ok := tunnelHandles[tunnelHandle]
|
||||
tunnelMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return -1
|
||||
}
|
||||
|
||||
bind, _ := handle.device.Bind().(conn.PeekLookAtSocketFd)
|
||||
if bind == nil {
|
||||
return -1
|
||||
}
|
||||
|
||||
fd, err := bind.PeekLookAtSocketFd4()
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
|
||||
return int32(fd)
|
||||
}
|
||||
|
||||
//export awgGetSocketV6
|
||||
func awgGetSocketV6(tunnelHandle int32) int32 {
|
||||
|
||||
tunnelMu.RLock()
|
||||
handle, ok := tunnelHandles[tunnelHandle]
|
||||
tunnelMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return -1
|
||||
}
|
||||
|
||||
bind, _ := handle.device.Bind().(conn.PeekLookAtSocketFd)
|
||||
if bind == nil {
|
||||
return -1
|
||||
}
|
||||
|
||||
fd, err := bind.PeekLookAtSocketFd6()
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
|
||||
return int32(fd)
|
||||
}
|
||||
|
||||
//export awgGetConfig
|
||||
func awgGetConfig(tunnelHandle int32) *C.char {
|
||||
|
||||
|
||||
@@ -11,8 +11,6 @@
|
||||
struct go_string { const char *str; long n; };
|
||||
extern int awgTurnOn(struct go_string ifname, int tun_fd, struct go_string settings, struct go_string uapipath);
|
||||
extern void awgTurnOff(int handle);
|
||||
extern int awgGetSocketV4(int handle);
|
||||
extern int awgGetSocketV6(int handle);
|
||||
extern char *awgGetConfig(int handle);
|
||||
extern char *awgVersion();
|
||||
extern int awgUpdateTunnelPeers(int handle, struct go_string settings);
|
||||
@@ -46,16 +44,6 @@ JNIEXPORT void JNICALL Java_com_zaneschepke_tunnel_VpnBackend_awgTurnOff(JNIEnv
|
||||
awgTurnOff(handle);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_com_zaneschepke_tunnel_VpnBackend_awgGetSocketV4(JNIEnv *env, jclass c, jint handle)
|
||||
{
|
||||
return awgGetSocketV4(handle);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_com_zaneschepke_tunnel_VpnBackend_awgGetSocketV6(JNIEnv *env, jclass c, jint handle)
|
||||
{
|
||||
return awgGetSocketV6(handle);
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL Java_com_zaneschepke_tunnel_VpnBackend_awgGetConfig(JNIEnv *env, jclass c, jint handle)
|
||||
{
|
||||
jstring ret;
|
||||
|
||||
Reference in New Issue
Block a user