mirror of
https://github.com/wgtunnel/android.git
synced 2026-07-03 14:07:49 +02:00
fix: private dns to use network bind, bootstrap custom with system dns
closes #1312 closes #1311 #1303 #1270
This commit is contained in:
@@ -81,7 +81,7 @@ object DnsValidator {
|
||||
return Result.Valid
|
||||
}
|
||||
|
||||
private fun validateUdp(value: String): DnsValidator.Result {
|
||||
private fun validateUdp(value: String): Result {
|
||||
val parts = value.split(":")
|
||||
|
||||
val host = parts.getOrNull(0)?.trim()
|
||||
@@ -93,14 +93,14 @@ object DnsValidator {
|
||||
|
||||
// basic IP/hostname sanity check
|
||||
if (!isValidHostOrIp(host)) {
|
||||
return DnsValidator.Result.Invalid(DnsError.InvalidIpOrHost)
|
||||
return Result.Invalid(DnsError.InvalidIpOrHost)
|
||||
}
|
||||
|
||||
if (port !in 1..65535) {
|
||||
return DnsValidator.Result.Invalid(DnsError.InvalidPort)
|
||||
return Result.Invalid(DnsError.InvalidPort)
|
||||
}
|
||||
|
||||
return DnsValidator.Result.Valid
|
||||
return Result.Valid
|
||||
}
|
||||
|
||||
private fun isValidHostOrIp(value: String): Boolean {
|
||||
|
||||
@@ -66,6 +66,7 @@ dependencies {
|
||||
|
||||
api(libs.amneziawg.parser)
|
||||
implementation(libs.libsu)
|
||||
implementation(libs.ipaddress)
|
||||
|
||||
implementation(libs.timber)
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import com.zaneschepke.tunnel.VpnBackend
|
||||
import com.zaneschepke.tunnel.backend.dns.EndpointResolver
|
||||
import com.zaneschepke.tunnel.event.TunnelEvent
|
||||
import com.zaneschepke.tunnel.model.BackendMode
|
||||
import com.zaneschepke.tunnel.model.DnsBoostrapConfig
|
||||
import com.zaneschepke.tunnel.model.DnsBoostrapMode
|
||||
import com.zaneschepke.tunnel.model.DnsBootstrapResult
|
||||
import com.zaneschepke.tunnel.model.Host
|
||||
@@ -642,7 +641,7 @@ class TunnelBackend(
|
||||
val updatedActiveTunnel = _status.value.activeTunnels[tunnelId] ?: return
|
||||
val tunnel = updatedActiveTunnel.tunnel ?: return
|
||||
|
||||
var results = endpointResolver.resolvePeers(mode)
|
||||
val results = endpointResolver.resolvePeers(mode)
|
||||
if (results.isEmpty()) return
|
||||
|
||||
val networkHasIpv6 = stableNetworkEngine.stableState.value?.state?.hasIpv6 == true
|
||||
@@ -659,23 +658,7 @@ class TunnelBackend(
|
||||
return
|
||||
} ?: return
|
||||
|
||||
var mismatches = findEndpointMismatches(results, activeConfig, preferIpv6)
|
||||
|
||||
if (
|
||||
reason == PeerUpdateReason.DDNS_CHECK && (results.isEmpty() || mismatches.isEmpty())
|
||||
) {
|
||||
Timber.w(
|
||||
"DNS resolution returned no new data, could be stale of cached. Switching to default DoH to avoid cache"
|
||||
)
|
||||
|
||||
val avoidCacheMode =
|
||||
DnsBoostrapMode.Custom(
|
||||
DnsBoostrapConfig.DoH(DnsBoostrapConfig.DEFAULT_DOH_UPSTREAM)
|
||||
)
|
||||
|
||||
results = endpointResolver.resolvePeers(mode, avoidCacheMode)
|
||||
mismatches = findEndpointMismatches(results, activeConfig, preferIpv6)
|
||||
}
|
||||
val mismatches = findEndpointMismatches(results, activeConfig, preferIpv6)
|
||||
|
||||
Timber.d("Reconciliation complete for $reason. Mismatches found: ${mismatches.size}")
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.zaneschepke.tunnel.backend.dns
|
||||
|
||||
import android.net.Network
|
||||
import com.zaneschepke.tunnel.model.DnsBootstrapResult
|
||||
import java.net.UnknownHostException
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import timber.log.Timber
|
||||
@@ -10,14 +11,19 @@ internal class AndroidNetworkResolver(private val network: Network) : PeerResolv
|
||||
|
||||
override suspend fun resolve(host: String): DnsBootstrapResult =
|
||||
withContext(Dispatchers.IO) {
|
||||
// use underlying network for resolution
|
||||
val ips = network.getAllByName(host)
|
||||
try {
|
||||
// use underlying network for resolution
|
||||
val ips = network.getAllByName(host)
|
||||
|
||||
Timber.d("Resolution from network bind socket: ${ips.contentToString()}")
|
||||
Timber.d("Resolution from network bind socket: ${ips.contentToString()}")
|
||||
|
||||
val v4 = ips.filter { it.address.size == 4 }.map { it.hostAddress }
|
||||
val v6 = ips.filter { it.address.size == 16 }.map { it.hostAddress }
|
||||
val v4 = ips.filter { it.address.size == 4 }.map { it.hostAddress }
|
||||
val v6 = ips.filter { it.address.size == 16 }.map { it.hostAddress }
|
||||
|
||||
DnsBootstrapResult(v4, v6)
|
||||
DnsBootstrapResult(v4, v6)
|
||||
} catch (e: UnknownHostException) {
|
||||
Timber.e(e, "System DNS failed to resolve host")
|
||||
DnsBootstrapResult()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,58 @@
|
||||
package com.zaneschepke.tunnel.backend.dns
|
||||
|
||||
import com.zaneschepke.tunnel.DnsConfigManager
|
||||
import android.net.Network
|
||||
import com.zaneschepke.tunnel.model.DnsBoostrapConfig
|
||||
import com.zaneschepke.tunnel.model.DnsBootstrapResult
|
||||
import com.zaneschepke.tunnel.util.DnsHostUtils
|
||||
import timber.log.Timber
|
||||
|
||||
class CustomDnsResolver(private val dnsConfig: DnsBoostrapConfig, private val bypass: Boolean) :
|
||||
PeerResolver {
|
||||
class CustomDnsResolver(
|
||||
private val dnsConfig: DnsBoostrapConfig,
|
||||
private val bypass: Boolean,
|
||||
network: Network,
|
||||
) : PeerResolver {
|
||||
|
||||
private val systemResolver = AndroidNetworkResolver(network)
|
||||
|
||||
override suspend fun resolve(host: String): DnsBootstrapResult {
|
||||
return DnsConfigManager.resolveHostBootstrap(
|
||||
host = host,
|
||||
protocol = dnsConfig.protocol,
|
||||
upstream = dnsConfig.upstream ?: DnsBoostrapConfig.DEFAULT_PLAIN_UPSTREAM,
|
||||
bypass = bypass,
|
||||
)
|
||||
|
||||
val upstream = dnsConfig.upstream
|
||||
if (upstream.isNullOrBlank()) {
|
||||
Timber.w("Custom DNS mode selected but no upstream configured")
|
||||
return DnsBootstrapResult()
|
||||
}
|
||||
|
||||
val resolvedUpstream =
|
||||
if (DnsHostUtils.needsResolution(upstream)) {
|
||||
Timber.d("Upstream DNS needs resolution, resolving via system resolver")
|
||||
val hostToResolve = DnsHostUtils.extractHost(upstream)
|
||||
|
||||
val resolutionResult = systemResolver.resolve(hostToResolve)
|
||||
|
||||
val ip = resolutionResult.ipv4.firstOrNull() ?: resolutionResult.ipv6.firstOrNull()
|
||||
if (ip == null) {
|
||||
Timber.w("Failed to resolve custom DNS upstream host: $upstream")
|
||||
return DnsBootstrapResult()
|
||||
}
|
||||
|
||||
DnsHostUtils.replaceHostWithIP(upstream, ip)
|
||||
} else {
|
||||
upstream
|
||||
}
|
||||
|
||||
Timber.d("Using custom resolver with resolved upstream $resolvedUpstream")
|
||||
|
||||
return try {
|
||||
NativeDnsResolver.resolveHostBootstrap(
|
||||
host = host,
|
||||
protocol = dnsConfig.protocol,
|
||||
resolvedUpstream = resolvedUpstream,
|
||||
originalUpstream = upstream,
|
||||
bypass = bypass,
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
Timber.w(e, "Custom DNS resolution failed for host=$host upstream=$resolvedUpstream")
|
||||
DnsBootstrapResult()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
package com.zaneschepke.tunnel.backend.dns
|
||||
|
||||
import android.net.Network
|
||||
import com.zaneschepke.networkmonitor.ConnectivityState
|
||||
import com.zaneschepke.networkmonitor.PrivateDnsMode
|
||||
import com.zaneschepke.networkmonitor.StableNetworkEngine
|
||||
import com.zaneschepke.tunnel.model.BackendMode
|
||||
import com.zaneschepke.tunnel.model.DnsBoostrapConfig
|
||||
import com.zaneschepke.tunnel.model.DnsBoostrapMode
|
||||
import com.zaneschepke.tunnel.model.DnsBootstrapResult
|
||||
import com.zaneschepke.tunnel.model.PublicKey
|
||||
@@ -21,141 +17,62 @@ class EndpointResolver(
|
||||
private val getDnsMode: () -> DnsBoostrapMode,
|
||||
private val isKillSwitchEnabled: () -> Boolean,
|
||||
) {
|
||||
suspend fun resolvePeers(
|
||||
mode: BackendMode,
|
||||
forceDnsMode: DnsBoostrapMode? = null,
|
||||
): Map<PublicKey, DnsBootstrapResult> = coroutineScope {
|
||||
val peersToResolve = mode.config.peers.filter { !it.isStaticallyConfigured }
|
||||
if (peersToResolve.isEmpty()) return@coroutineScope emptyMap()
|
||||
suspend fun resolvePeers(mode: BackendMode): Map<PublicKey, DnsBootstrapResult> =
|
||||
coroutineScope {
|
||||
val peersToResolve = mode.config.peers.filter { !it.isStaticallyConfigured }
|
||||
if (peersToResolve.isEmpty()) return@coroutineScope emptyMap()
|
||||
|
||||
val results = mutableMapOf<PublicKey, DnsBootstrapResult>()
|
||||
stableNetworkEngine.stableState.first { it?.state?.activeNetwork?.network != null }
|
||||
val results = mutableMapOf<PublicKey, DnsBootstrapResult>()
|
||||
stableNetworkEngine.stableState.first { it?.state?.activeNetwork?.network != null }
|
||||
|
||||
var delayMs = 500L
|
||||
var delayMs = 500L
|
||||
|
||||
while (isActive) {
|
||||
val snapshot = stableNetworkEngine.stableState.value?.state
|
||||
val network =
|
||||
snapshot?.activeNetwork?.network
|
||||
?: run {
|
||||
delay(100.milliseconds)
|
||||
continue
|
||||
}
|
||||
|
||||
val dnsMode = forceDnsMode ?: getDnsMode()
|
||||
val bypassNeeded = mode is BackendMode.Vpn || isKillSwitchEnabled()
|
||||
var progressed = false
|
||||
|
||||
for (peer in peersToResolve) {
|
||||
if (results.containsKey(peer.publicKey)) continue
|
||||
val host = peer.endpoint?.substringBeforeLast(":") ?: continue
|
||||
|
||||
val dnsResult =
|
||||
when (dnsMode) {
|
||||
is DnsBoostrapMode.Custom -> {
|
||||
resolveWithCustomConfig(dnsMode.config, host, bypassNeeded)
|
||||
while (isActive) {
|
||||
val snapshot = stableNetworkEngine.stableState.value?.state
|
||||
val network =
|
||||
snapshot?.activeNetwork?.network
|
||||
?: run {
|
||||
delay(100.milliseconds)
|
||||
continue
|
||||
}
|
||||
is DnsBoostrapMode.System -> {
|
||||
resolveWithSystemStrategy(snapshot, network, host, bypassNeeded)
|
||||
|
||||
val dnsMode = getDnsMode()
|
||||
val bypassNeeded = mode is BackendMode.Vpn || isKillSwitchEnabled()
|
||||
var progressed = false
|
||||
|
||||
for (peer in peersToResolve) {
|
||||
if (results.containsKey(peer.publicKey)) continue
|
||||
val host = peer.endpoint?.substringBeforeLast(":") ?: continue
|
||||
|
||||
val resolver: PeerResolver =
|
||||
when (dnsMode) {
|
||||
is DnsBoostrapMode.System -> AndroidNetworkResolver(network)
|
||||
is DnsBoostrapMode.Custom ->
|
||||
CustomDnsResolver(dnsMode.config, bypassNeeded, network)
|
||||
}
|
||||
|
||||
val result = resolver.resolve(host)
|
||||
|
||||
if (result.ipv4.isNotEmpty() || result.ipv6.isNotEmpty()) {
|
||||
results[peer.publicKey] = result.copy(ipv6 = result.ipv6.map { "[$it]" })
|
||||
progressed = true
|
||||
}
|
||||
|
||||
if (
|
||||
dnsResult != null &&
|
||||
(dnsResult.ipv4.isNotEmpty() || dnsResult.ipv6.isNotEmpty())
|
||||
) {
|
||||
results[peer.publicKey] = dnsResult.copy(ipv6 = dnsResult.ipv6.map { "[$it]" })
|
||||
progressed = true
|
||||
}
|
||||
}
|
||||
|
||||
if (results.keys.containsAll(peersToResolve.map { it.publicKey })) {
|
||||
Timber.d("All peers resolved")
|
||||
return@coroutineScope results
|
||||
}
|
||||
|
||||
if (!progressed) {
|
||||
delay(delayMs.milliseconds)
|
||||
delayMs = (delayMs * 2).coerceAtMost(MAX_BACKOFF)
|
||||
} else {
|
||||
delayMs = 500L // reset after we have progressed
|
||||
}
|
||||
}
|
||||
return@coroutineScope results
|
||||
}
|
||||
|
||||
private suspend fun resolveWithSystemStrategy(
|
||||
snapshot: ConnectivityState,
|
||||
network: Network,
|
||||
host: String,
|
||||
bypass: Boolean,
|
||||
): DnsBootstrapResult? {
|
||||
val dnsInfo = snapshot.underlyingDnsInfo
|
||||
val hasDnsServers = dnsInfo.servers.isNotEmpty()
|
||||
val hasPrivateDnsHostname =
|
||||
dnsInfo.privateDnsMode == PrivateDnsMode.HOSTNAME &&
|
||||
!dnsInfo.privateDnsHostname.isNullOrBlank()
|
||||
|
||||
return when {
|
||||
// Private DNS hostname, use DoT/DoH via custom resolver
|
||||
hasPrivateDnsHostname -> {
|
||||
val hostname = dnsInfo.privateDnsHostname!!
|
||||
val config =
|
||||
DnsBoostrapConfig.SPECIAL_ANDROID_DOH_SERVERS[hostname]?.let {
|
||||
DnsBoostrapConfig.DoH(it)
|
||||
} ?: DnsBoostrapConfig.DoT(hostname)
|
||||
|
||||
Timber.d("System and Private DNS, using ${config.protocol} for $host")
|
||||
resolveWithCustomConfig(config, host, bypass)
|
||||
}
|
||||
|
||||
// Normal system DNS
|
||||
hasDnsServers -> {
|
||||
try {
|
||||
Timber.d("Using system DNS with network provided DNS servers")
|
||||
AndroidNetworkResolver(network).resolve(host)
|
||||
} catch (e: Exception) {
|
||||
Timber.w(e, "AndroidNetworkResolver failed for $host")
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
// No DNS servers on network, fall back to custom with well known
|
||||
else -> {
|
||||
Timber.d("No DNS servers on network, falling back to public DNS for $host")
|
||||
val publicConfig = DnsBoostrapConfig.Plain(DnsBoostrapConfig.DEFAULT_PLAIN_UPSTREAM)
|
||||
resolveWithCustomConfig(publicConfig, host, bypass)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun resolveWithCustomConfig(
|
||||
config: DnsBoostrapConfig,
|
||||
host: String,
|
||||
bypass: Boolean,
|
||||
): DnsBootstrapResult? {
|
||||
val upstream =
|
||||
config.upstream
|
||||
?: when (config) {
|
||||
is DnsBoostrapConfig.DoH -> DnsBoostrapConfig.DEFAULT_DOH_UPSTREAM
|
||||
is DnsBoostrapConfig.DoT -> DnsBoostrapConfig.DEFAULT_DOT_UPSTREAM
|
||||
is DnsBoostrapConfig.Plain -> DnsBoostrapConfig.DEFAULT_PLAIN_UPSTREAM
|
||||
}
|
||||
|
||||
return try {
|
||||
CustomDnsResolver(config, bypass).resolve(host)
|
||||
} catch (e: Exception) {
|
||||
Timber.w(
|
||||
e,
|
||||
"DNS resolution failed for host=%s protocol=%s upstream=%s bypass=%s",
|
||||
host,
|
||||
config.protocol,
|
||||
upstream,
|
||||
bypass,
|
||||
)
|
||||
null
|
||||
if (results.keys.containsAll(peersToResolve.map { it.publicKey })) {
|
||||
Timber.d("All peers resolved")
|
||||
return@coroutineScope results
|
||||
}
|
||||
|
||||
if (!progressed) {
|
||||
delay(delayMs.milliseconds)
|
||||
delayMs = (delayMs * 2).coerceAtMost(MAX_BACKOFF)
|
||||
} else {
|
||||
delayMs = 500L // reset after we have progressed
|
||||
}
|
||||
}
|
||||
return@coroutineScope results
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
private const val MAX_BACKOFF = 30_000L
|
||||
|
||||
+14
-8
@@ -1,30 +1,36 @@
|
||||
package com.zaneschepke.tunnel
|
||||
package com.zaneschepke.tunnel.backend.dns
|
||||
|
||||
import com.zaneschepke.tunnel.model.DnsBoostrapConfig
|
||||
import com.zaneschepke.tunnel.model.DnsBootstrapResult
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
|
||||
internal object DnsConfigManager {
|
||||
internal object NativeDnsResolver {
|
||||
|
||||
private external fun resolveBootstrap(
|
||||
host: String,
|
||||
protocol: String,
|
||||
upstream: String,
|
||||
underlyingDnsServers: String,
|
||||
resolvedUpstream: String,
|
||||
originalUpstream: String,
|
||||
bypass: Int,
|
||||
): String
|
||||
|
||||
suspend fun resolveHostBootstrap(
|
||||
host: String,
|
||||
protocol: String,
|
||||
upstream: String,
|
||||
underlyingDnsServers: String = DnsBoostrapConfig.DEFAULT_UNDERLYING_SERVERS,
|
||||
resolvedUpstream: String,
|
||||
originalUpstream: String,
|
||||
bypass: Boolean,
|
||||
): DnsBootstrapResult =
|
||||
withContext(Dispatchers.IO) {
|
||||
val bypassOption = if (bypass) 1 else 0
|
||||
val raw = resolveBootstrap(host, protocol, upstream, underlyingDnsServers, bypassOption)
|
||||
val raw =
|
||||
resolveBootstrap(
|
||||
host = host,
|
||||
protocol = protocol,
|
||||
resolvedUpstream = resolvedUpstream,
|
||||
originalUpstream = originalUpstream,
|
||||
bypass = bypassOption,
|
||||
)
|
||||
|
||||
if (raw.startsWith("ERR|")) {
|
||||
throw RuntimeException(raw.removePrefix("ERR|"))
|
||||
@@ -24,18 +24,6 @@ sealed class DnsBoostrapConfig(open val upstream: String?) {
|
||||
override val protocol: String
|
||||
get() = "dot"
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val DEFAULT_UNDERLYING_SERVERS = "1.1.1.1,8.8.8.8"
|
||||
const val DEFAULT_PLAIN_UPSTREAM = "1.1.1.1"
|
||||
const val DEFAULT_DOH_UPSTREAM = "https://cloudflare-dns.com/dns-query"
|
||||
const val DEFAULT_DOT_UPSTREAM = "one.one.one.one"
|
||||
val SPECIAL_ANDROID_DOH_SERVERS =
|
||||
mapOf(
|
||||
"cloudflare-dns.com" to "https://cloudflare-dns.com/dns-query",
|
||||
"dns.google" to "https://dns.google/dns-query",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
data class DnsBootstrapResult(
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
package com.zaneschepke.tunnel.util
|
||||
|
||||
import inet.ipaddr.IPAddressString
|
||||
import java.net.URI
|
||||
|
||||
object DnsHostUtils {
|
||||
|
||||
/** Extracts the host portion from a DoH/DoT/Plain upstream string. */
|
||||
fun extractHost(upstream: String): String {
|
||||
val trimmed = upstream.trim()
|
||||
|
||||
// DoH full url
|
||||
if (trimmed.startsWith("http://") || trimmed.startsWith("https://")) {
|
||||
return try {
|
||||
URI(trimmed).host ?: trimmed
|
||||
} catch (_: Exception) {
|
||||
trimmed
|
||||
}
|
||||
}
|
||||
|
||||
val hostPart = trimmed.substringBeforeLast(":")
|
||||
return hostPart.removeSurrounding("[", "]")
|
||||
}
|
||||
|
||||
/** Replaces the hostname in the upstream string with the given IP address. */
|
||||
fun replaceHostWithIP(upstream: String, newIp: String): String {
|
||||
val trimmed = upstream.trim()
|
||||
|
||||
val cleanedIp = newIp.trim().removeSurrounding("[", "]")
|
||||
val isIpv6 = isIpAddress(cleanedIp) && cleanedIp.contains(":")
|
||||
|
||||
val replacementIp = if (isIpv6) "[$cleanedIp]" else cleanedIp
|
||||
|
||||
// handle full url for DoH
|
||||
if (trimmed.startsWith("http://") || trimmed.startsWith("https://")) {
|
||||
return try {
|
||||
val uri = URI(trimmed)
|
||||
val newAuthority =
|
||||
if (uri.port != -1) {
|
||||
"$replacementIp:${uri.port}"
|
||||
} else {
|
||||
replacementIp
|
||||
}
|
||||
|
||||
URI(uri.scheme, newAuthority, uri.path, uri.query, uri.fragment).toString()
|
||||
} catch (_: Exception) {
|
||||
// ust return the IP if URL parsing fails
|
||||
replacementIp
|
||||
}
|
||||
}
|
||||
|
||||
// host:port format DoT and plain
|
||||
if (trimmed.contains(":")) {
|
||||
val port = trimmed.substringAfterLast(":")
|
||||
// Only treat as port if it's numeric
|
||||
if (port.toIntOrNull() != null) {
|
||||
return "$replacementIp:$port"
|
||||
}
|
||||
}
|
||||
|
||||
// bare hostname/ip
|
||||
return replacementIp
|
||||
}
|
||||
|
||||
fun isIpAddress(host: String): Boolean {
|
||||
val cleaned = host.trim().removeSurrounding("[", "]")
|
||||
return try {
|
||||
val addr = IPAddressString(cleaned).address
|
||||
addr != null && (addr.isIPv4 || addr.isIPv6)
|
||||
} catch (_: Exception) {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fun needsResolution(upstream: String): Boolean {
|
||||
val host = extractHost(upstream)
|
||||
return host.isNotBlank() && !isIpAddress(host)
|
||||
}
|
||||
}
|
||||
@@ -43,47 +43,46 @@ type Transport interface {
|
||||
func ResolveBootstrap(
|
||||
host *C.char,
|
||||
protocol *C.char,
|
||||
upstream *C.char,
|
||||
underlyingDnsServers *C.char,
|
||||
resolvedUpstream *C.char,
|
||||
originalUpstream *C.char,
|
||||
bypass C.int,
|
||||
) *C.char {
|
||||
|
||||
h := C.GoString(host)
|
||||
p := C.GoString(protocol)
|
||||
u := C.GoString(upstream)
|
||||
underlying := C.GoString(underlyingDnsServers)
|
||||
resolved := C.GoString(resolvedUpstream)
|
||||
original := C.GoString(originalUpstream)
|
||||
bp := bypass == 1
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
shared.LogDebug(
|
||||
"DNS",
|
||||
"ResolveBootstrap called host=%s protocol=%s upstream=%s bypass=%t",
|
||||
h, p, u, bp,
|
||||
)
|
||||
shared.LogDebug("DNS", "ResolveBootstrap called host=%s protocol=%s resolved=%s original=%s bypass=%t",
|
||||
h, p, resolved, original, bp)
|
||||
|
||||
v4, v6, err := Resolve(ctx, h, p, u, bp, underlying)
|
||||
v4, v6, err := Resolve(ctx, h, p, resolved, original, bp)
|
||||
if err != nil {
|
||||
shared.LogError("DNS", "ResolveBootstrap failed for %s: %v", h, err)
|
||||
return C.CString("ERR|" + err.Error())
|
||||
}
|
||||
|
||||
v4Str := make([]string, len(v4))
|
||||
for i, ip := range v4 {
|
||||
v4Str[i] = ip.String()
|
||||
}
|
||||
v6Str := make([]string, len(v6))
|
||||
for i, ip := range v6 {
|
||||
v6Str[i] = ip.String()
|
||||
}
|
||||
|
||||
result := "v4=" + strings.Join(v4Str, ",") +
|
||||
";v6=" + strings.Join(v6Str, ",")
|
||||
result := fmt.Sprintf("v4=%s;v6=%s",
|
||||
strings.Join(toStringSlice(v4), ","),
|
||||
strings.Join(toStringSlice(v6), ","),
|
||||
)
|
||||
|
||||
shared.LogDebug("DNS", "ResolveBootstrap success for %s: %s", h, result)
|
||||
return C.CString(result)
|
||||
}
|
||||
|
||||
func toStringSlice(addrs []netip.Addr) []string {
|
||||
out := make([]string, len(addrs))
|
||||
for i, a := range addrs {
|
||||
out[i] = a.String()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type DoTTransport struct {
|
||||
Client *dns.Client
|
||||
Servers []string
|
||||
@@ -264,20 +263,26 @@ func resolveServerAddrs(
|
||||
|
||||
func (t PlainTransport) Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
for _, server := range t.Servers {
|
||||
m, _, err := t.Client.Exchange(msg, server)
|
||||
m, _, err := t.Client.ExchangeContext(ctx, msg, server)
|
||||
if err == nil && m != nil && m.Rcode == dns.RcodeSuccess {
|
||||
return m, nil
|
||||
}
|
||||
if err != nil {
|
||||
shared.LogDebug("DNS", "Plain DNS query to %s failed: %v", server, err)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("all DNS servers failed")
|
||||
}
|
||||
|
||||
func (t DoTTransport) Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
for _, server := range t.Servers {
|
||||
m, _, err := t.Client.Exchange(msg, server)
|
||||
m, _, err := t.Client.ExchangeContext(ctx, msg, server)
|
||||
if err == nil && m != nil && m.Rcode == dns.RcodeSuccess {
|
||||
return m, nil
|
||||
}
|
||||
if err != nil {
|
||||
shared.LogDebug("DNS", "DoT Exchange to %s failed: %v", server, err)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("all DoT servers failed")
|
||||
}
|
||||
@@ -343,11 +348,11 @@ func parseDNSAnswers(msg *dns.Msg, qtype uint16) []netip.Addr {
|
||||
|
||||
func Resolve(
|
||||
ctx context.Context,
|
||||
host, protocol, upstream string,
|
||||
host, protocol, resolvedUpstream, originalUpstream string,
|
||||
bypass bool,
|
||||
underlying string,
|
||||
) ([]netip.Addr, []netip.Addr, error) {
|
||||
t, err := buildTransport(ctx, protocol, upstream, bypass, underlying)
|
||||
|
||||
t, err := buildTransport(protocol, resolvedUpstream, originalUpstream, bypass)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -355,88 +360,89 @@ func Resolve(
|
||||
}
|
||||
|
||||
func buildTransport(
|
||||
ctx context.Context,
|
||||
protocol, upstream string,
|
||||
protocol, resolvedUpstream, originalUpstream string,
|
||||
bypass bool,
|
||||
underlying string,
|
||||
) (Transport, error) {
|
||||
|
||||
switch protocol {
|
||||
case "doh":
|
||||
u, err := url.Parse(upstream)
|
||||
// Parse original for SNI
|
||||
origURL, err := url.Parse(originalUpstream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("invalid original DoH upstream: %w", err)
|
||||
}
|
||||
hostname := u.Hostname()
|
||||
port := u.Port()
|
||||
|
||||
originalHost := origURL.Hostname()
|
||||
|
||||
// Parse resolved to get the IP
|
||||
resolvedURL, _ := url.Parse(resolvedUpstream)
|
||||
dialHost := resolvedURL.Hostname()
|
||||
if dialHost == "" {
|
||||
dialHost = originalHost // fallback
|
||||
}
|
||||
|
||||
port := origURL.Port()
|
||||
if port == "" {
|
||||
port = "443"
|
||||
}
|
||||
u.Host = net.JoinHostPort(hostname, port)
|
||||
|
||||
// Pre-resolve with IPv4-first ordering + bypass
|
||||
servers, _, err := resolveServerAddrs(ctx, u.Host, bypass, "443", underlying)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(servers) == 0 {
|
||||
return nil, fmt.Errorf("no addresses resolved for DoH server")
|
||||
}
|
||||
|
||||
// Custom dialer that tries servers in order
|
||||
// tries ipv4 first and then ipv6
|
||||
dialer := GetDialer(bypass)
|
||||
|
||||
transport := &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
|
||||
for _, addr := range servers {
|
||||
conn, err := dialer.DialContext(ctx, network, addr)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("all DoH addresses failed")
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, network, net.JoinHostPort(dialHost, port))
|
||||
},
|
||||
TLSClientConfig: &tls.Config{
|
||||
ServerName: hostname,
|
||||
ServerName: originalHost, // Use original hostname for certificate validation
|
||||
},
|
||||
}
|
||||
|
||||
finalURL := origURL.String()
|
||||
if !strings.HasPrefix(finalURL, "https://") {
|
||||
finalURL = "https://" + finalURL
|
||||
}
|
||||
|
||||
return DoHTransport{
|
||||
Client: &http.Client{Timeout: 5 * time.Second, Transport: transport},
|
||||
URL: u.String(),
|
||||
Servers: servers,
|
||||
Hostname: hostname,
|
||||
URL: finalURL,
|
||||
Hostname: originalHost,
|
||||
}, nil
|
||||
|
||||
case "dot":
|
||||
servers, sni, err := resolveServerAddrs(ctx, upstream, bypass, "853", underlying)
|
||||
// Get SNI from original
|
||||
origHost, origPort, err := net.SplitHostPort(originalUpstream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
origHost = originalUpstream
|
||||
origPort = "853"
|
||||
}
|
||||
if len(servers) == 0 {
|
||||
return nil, fmt.Errorf("no addresses resolved for DoT server")
|
||||
|
||||
// Get connection target from resolved
|
||||
resolvedHost, resolvedPort, _ := net.SplitHostPort(resolvedUpstream)
|
||||
if resolvedHost == "" {
|
||||
resolvedHost = resolvedUpstream
|
||||
resolvedPort = origPort
|
||||
}
|
||||
|
||||
client := &dns.Client{
|
||||
Net: "tcp-tls",
|
||||
Dialer: GetDialer(bypass),
|
||||
Timeout: 5 * time.Second,
|
||||
Timeout: 6 * time.Second,
|
||||
TLSConfig: &tls.Config{
|
||||
ServerName: sni,
|
||||
ServerName: origHost,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
}
|
||||
|
||||
return DoTTransport{
|
||||
Client: client,
|
||||
Servers: servers,
|
||||
Servers: []string{net.JoinHostPort(resolvedHost, resolvedPort)},
|
||||
}, nil
|
||||
|
||||
default: // plain DNS
|
||||
_, addr, err := parseUpstream(upstream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
servers, _, err := resolveServerAddrs(ctx, addr, bypass, "53", underlying)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
default: // plain
|
||||
host, port, _ := net.SplitHostPort(resolvedUpstream)
|
||||
if host == "" {
|
||||
host = resolvedUpstream
|
||||
port = "53"
|
||||
}
|
||||
|
||||
client := &dns.Client{
|
||||
@@ -446,7 +452,7 @@ func buildTransport(
|
||||
}
|
||||
return PlainTransport{
|
||||
Client: client,
|
||||
Servers: servers,
|
||||
Servers: []string{net.JoinHostPort(host, port)},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,45 +6,45 @@ struct go_string { const char *str; long n; };
|
||||
extern char* ResolveBootstrap(
|
||||
const char* host,
|
||||
const char* protocol,
|
||||
const char* upstream,
|
||||
const char* underlyingDnsServers,
|
||||
const char* resolvedUpstream,
|
||||
const char* originalUpstream,
|
||||
int bypass);
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_zaneschepke_tunnel_DnsConfigManager_resolveBootstrap(
|
||||
Java_com_zaneschepke_tunnel_backend_dns_NativeDnsResolver_resolveBootstrap(
|
||||
JNIEnv* env,
|
||||
jclass clazz,
|
||||
jstring host,
|
||||
jstring protocol,
|
||||
jstring upstream,
|
||||
jstring underlyingDnsServers,
|
||||
jstring resolvedUpstream,
|
||||
jstring originalUpstream,
|
||||
jint bypass)
|
||||
{
|
||||
if (host == NULL || protocol == NULL || upstream == NULL || underlyingDnsServers == NULL) {
|
||||
if (host == NULL || protocol == NULL || resolvedUpstream == NULL || originalUpstream == NULL) {
|
||||
return (*env)->NewStringUTF(env, "ERR|invalid arguments");
|
||||
}
|
||||
|
||||
const char* chost = (*env)->GetStringUTFChars(env, host, NULL);
|
||||
const char* cprotocol = (*env)->GetStringUTFChars(env, protocol, NULL);
|
||||
const char* cupstream = (*env)->GetStringUTFChars(env, upstream, NULL);
|
||||
const char* cunderlying = (*env)->GetStringUTFChars(env, underlyingDnsServers, NULL);
|
||||
const char* chost = (*env)->GetStringUTFChars(env, host, NULL);
|
||||
const char* cprotocol = (*env)->GetStringUTFChars(env, protocol, NULL);
|
||||
const char* cresolvedUpstream = (*env)->GetStringUTFChars(env, resolvedUpstream, NULL);
|
||||
const char* coriginalUpstream = (*env)->GetStringUTFChars(env, originalUpstream, NULL);
|
||||
|
||||
if (chost == NULL || cprotocol == NULL || cupstream == NULL || cunderlying == NULL) {
|
||||
if (chost == NULL || cprotocol == NULL || cresolvedUpstream == NULL || coriginalUpstream == NULL) {
|
||||
return (*env)->NewStringUTF(env, "ERR|out of memory");
|
||||
}
|
||||
|
||||
char* resultC = ResolveBootstrap(
|
||||
chost,
|
||||
cprotocol,
|
||||
cupstream,
|
||||
cunderlying,
|
||||
cresolvedUpstream,
|
||||
coriginalUpstream,
|
||||
bypass ? 1 : 0
|
||||
);
|
||||
|
||||
(*env)->ReleaseStringUTFChars(env, host, chost);
|
||||
(*env)->ReleaseStringUTFChars(env, protocol, cprotocol);
|
||||
(*env)->ReleaseStringUTFChars(env, upstream, cupstream);
|
||||
(*env)->ReleaseStringUTFChars(env, underlyingDnsServers, cunderlying);
|
||||
(*env)->ReleaseStringUTFChars(env, resolvedUpstream, cresolvedUpstream);
|
||||
(*env)->ReleaseStringUTFChars(env, originalUpstream, coriginalUpstream);
|
||||
|
||||
if (resultC == NULL) {
|
||||
return (*env)->NewStringUTF(env, "ERR|null response");
|
||||
|
||||
Reference in New Issue
Block a user