diff --git a/app/src/main/java/com/zaneschepke/wireguardautotunnel/util/DnsValidator.kt b/app/src/main/java/com/zaneschepke/wireguardautotunnel/util/DnsValidator.kt index ed605e6e..1daa0f7a 100644 --- a/app/src/main/java/com/zaneschepke/wireguardautotunnel/util/DnsValidator.kt +++ b/app/src/main/java/com/zaneschepke/wireguardautotunnel/util/DnsValidator.kt @@ -81,7 +81,7 @@ object DnsValidator { return Result.Valid } - private fun validateUdp(value: String): DnsValidator.Result { + private fun validateUdp(value: String): Result { val parts = value.split(":") val host = parts.getOrNull(0)?.trim() @@ -93,14 +93,14 @@ object DnsValidator { // basic IP/hostname sanity check if (!isValidHostOrIp(host)) { - return DnsValidator.Result.Invalid(DnsError.InvalidIpOrHost) + return Result.Invalid(DnsError.InvalidIpOrHost) } if (port !in 1..65535) { - return DnsValidator.Result.Invalid(DnsError.InvalidPort) + return Result.Invalid(DnsError.InvalidPort) } - return DnsValidator.Result.Valid + return Result.Valid } private fun isValidHostOrIp(value: String): Boolean { diff --git a/tunnel/build.gradle.kts b/tunnel/build.gradle.kts index 57f28654..64b0e269 100644 --- a/tunnel/build.gradle.kts +++ b/tunnel/build.gradle.kts @@ -66,6 +66,7 @@ dependencies { api(libs.amneziawg.parser) implementation(libs.libsu) + implementation(libs.ipaddress) implementation(libs.timber) diff --git a/tunnel/src/main/java/com/zaneschepke/tunnel/backend/TunnelBackend.kt b/tunnel/src/main/java/com/zaneschepke/tunnel/backend/TunnelBackend.kt index a775ba2b..f9c9c256 100644 --- a/tunnel/src/main/java/com/zaneschepke/tunnel/backend/TunnelBackend.kt +++ b/tunnel/src/main/java/com/zaneschepke/tunnel/backend/TunnelBackend.kt @@ -8,7 +8,6 @@ import com.zaneschepke.tunnel.VpnBackend import com.zaneschepke.tunnel.backend.dns.EndpointResolver import com.zaneschepke.tunnel.event.TunnelEvent import com.zaneschepke.tunnel.model.BackendMode -import com.zaneschepke.tunnel.model.DnsBoostrapConfig import com.zaneschepke.tunnel.model.DnsBoostrapMode import com.zaneschepke.tunnel.model.DnsBootstrapResult import com.zaneschepke.tunnel.model.Host @@ -642,7 +641,7 @@ class TunnelBackend( val updatedActiveTunnel = _status.value.activeTunnels[tunnelId] ?: return val tunnel = updatedActiveTunnel.tunnel ?: return - var results = endpointResolver.resolvePeers(mode) + val results = endpointResolver.resolvePeers(mode) if (results.isEmpty()) return val networkHasIpv6 = stableNetworkEngine.stableState.value?.state?.hasIpv6 == true @@ -659,23 +658,7 @@ class TunnelBackend( return } ?: return - var mismatches = findEndpointMismatches(results, activeConfig, preferIpv6) - - if ( - reason == PeerUpdateReason.DDNS_CHECK && (results.isEmpty() || mismatches.isEmpty()) - ) { - Timber.w( - "DNS resolution returned no new data, could be stale of cached. Switching to default DoH to avoid cache" - ) - - val avoidCacheMode = - DnsBoostrapMode.Custom( - DnsBoostrapConfig.DoH(DnsBoostrapConfig.DEFAULT_DOH_UPSTREAM) - ) - - results = endpointResolver.resolvePeers(mode, avoidCacheMode) - mismatches = findEndpointMismatches(results, activeConfig, preferIpv6) - } + val mismatches = findEndpointMismatches(results, activeConfig, preferIpv6) Timber.d("Reconciliation complete for $reason. Mismatches found: ${mismatches.size}") diff --git a/tunnel/src/main/java/com/zaneschepke/tunnel/backend/dns/AndroidNetworkResolver.kt b/tunnel/src/main/java/com/zaneschepke/tunnel/backend/dns/AndroidNetworkResolver.kt index bfc660bb..b8b0f06c 100644 --- a/tunnel/src/main/java/com/zaneschepke/tunnel/backend/dns/AndroidNetworkResolver.kt +++ b/tunnel/src/main/java/com/zaneschepke/tunnel/backend/dns/AndroidNetworkResolver.kt @@ -2,6 +2,7 @@ package com.zaneschepke.tunnel.backend.dns import android.net.Network import com.zaneschepke.tunnel.model.DnsBootstrapResult +import java.net.UnknownHostException import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext import timber.log.Timber @@ -10,14 +11,19 @@ internal class AndroidNetworkResolver(private val network: Network) : PeerResolv override suspend fun resolve(host: String): DnsBootstrapResult = withContext(Dispatchers.IO) { - // use underlying network for resolution - val ips = network.getAllByName(host) + try { + // use underlying network for resolution + val ips = network.getAllByName(host) - Timber.d("Resolution from network bind socket: ${ips.contentToString()}") + Timber.d("Resolution from network bind socket: ${ips.contentToString()}") - val v4 = ips.filter { it.address.size == 4 }.map { it.hostAddress } - val v6 = ips.filter { it.address.size == 16 }.map { it.hostAddress } + val v4 = ips.filter { it.address.size == 4 }.map { it.hostAddress } + val v6 = ips.filter { it.address.size == 16 }.map { it.hostAddress } - DnsBootstrapResult(v4, v6) + DnsBootstrapResult(v4, v6) + } catch (e: UnknownHostException) { + Timber.e(e, "System DNS failed to resolve host") + DnsBootstrapResult() + } } } diff --git a/tunnel/src/main/java/com/zaneschepke/tunnel/backend/dns/CustomDnsResolver.kt b/tunnel/src/main/java/com/zaneschepke/tunnel/backend/dns/CustomDnsResolver.kt index 8554ba1a..90d9474f 100644 --- a/tunnel/src/main/java/com/zaneschepke/tunnel/backend/dns/CustomDnsResolver.kt +++ b/tunnel/src/main/java/com/zaneschepke/tunnel/backend/dns/CustomDnsResolver.kt @@ -1,18 +1,58 @@ package com.zaneschepke.tunnel.backend.dns -import com.zaneschepke.tunnel.DnsConfigManager +import android.net.Network import com.zaneschepke.tunnel.model.DnsBoostrapConfig import com.zaneschepke.tunnel.model.DnsBootstrapResult +import com.zaneschepke.tunnel.util.DnsHostUtils +import timber.log.Timber -class CustomDnsResolver(private val dnsConfig: DnsBoostrapConfig, private val bypass: Boolean) : - PeerResolver { +class CustomDnsResolver( + private val dnsConfig: DnsBoostrapConfig, + private val bypass: Boolean, + network: Network, +) : PeerResolver { + + private val systemResolver = AndroidNetworkResolver(network) override suspend fun resolve(host: String): DnsBootstrapResult { - return DnsConfigManager.resolveHostBootstrap( - host = host, - protocol = dnsConfig.protocol, - upstream = dnsConfig.upstream ?: DnsBoostrapConfig.DEFAULT_PLAIN_UPSTREAM, - bypass = bypass, - ) + + val upstream = dnsConfig.upstream + if (upstream.isNullOrBlank()) { + Timber.w("Custom DNS mode selected but no upstream configured") + return DnsBootstrapResult() + } + + val resolvedUpstream = + if (DnsHostUtils.needsResolution(upstream)) { + Timber.d("Upstream DNS needs resolution, resolving via system resolver") + val hostToResolve = DnsHostUtils.extractHost(upstream) + + val resolutionResult = systemResolver.resolve(hostToResolve) + + val ip = resolutionResult.ipv4.firstOrNull() ?: resolutionResult.ipv6.firstOrNull() + if (ip == null) { + Timber.w("Failed to resolve custom DNS upstream host: $upstream") + return DnsBootstrapResult() + } + + DnsHostUtils.replaceHostWithIP(upstream, ip) + } else { + upstream + } + + Timber.d("Using custom resolver with resolved upstream $resolvedUpstream") + + return try { + NativeDnsResolver.resolveHostBootstrap( + host = host, + protocol = dnsConfig.protocol, + resolvedUpstream = resolvedUpstream, + originalUpstream = upstream, + bypass = bypass, + ) + } catch (e: Exception) { + Timber.w(e, "Custom DNS resolution failed for host=$host upstream=$resolvedUpstream") + DnsBootstrapResult() + } } } diff --git a/tunnel/src/main/java/com/zaneschepke/tunnel/backend/dns/EndpointResolver.kt b/tunnel/src/main/java/com/zaneschepke/tunnel/backend/dns/EndpointResolver.kt index 898c90a3..bcc752c2 100644 --- a/tunnel/src/main/java/com/zaneschepke/tunnel/backend/dns/EndpointResolver.kt +++ b/tunnel/src/main/java/com/zaneschepke/tunnel/backend/dns/EndpointResolver.kt @@ -1,11 +1,7 @@ package com.zaneschepke.tunnel.backend.dns -import android.net.Network -import com.zaneschepke.networkmonitor.ConnectivityState -import com.zaneschepke.networkmonitor.PrivateDnsMode import com.zaneschepke.networkmonitor.StableNetworkEngine import com.zaneschepke.tunnel.model.BackendMode -import com.zaneschepke.tunnel.model.DnsBoostrapConfig import com.zaneschepke.tunnel.model.DnsBoostrapMode import com.zaneschepke.tunnel.model.DnsBootstrapResult import com.zaneschepke.tunnel.model.PublicKey @@ -21,141 +17,62 @@ class EndpointResolver( private val getDnsMode: () -> DnsBoostrapMode, private val isKillSwitchEnabled: () -> Boolean, ) { - suspend fun resolvePeers( - mode: BackendMode, - forceDnsMode: DnsBoostrapMode? = null, - ): Map = coroutineScope { - val peersToResolve = mode.config.peers.filter { !it.isStaticallyConfigured } - if (peersToResolve.isEmpty()) return@coroutineScope emptyMap() + suspend fun resolvePeers(mode: BackendMode): Map = + coroutineScope { + val peersToResolve = mode.config.peers.filter { !it.isStaticallyConfigured } + if (peersToResolve.isEmpty()) return@coroutineScope emptyMap() - val results = mutableMapOf() - stableNetworkEngine.stableState.first { it?.state?.activeNetwork?.network != null } + val results = mutableMapOf() + stableNetworkEngine.stableState.first { it?.state?.activeNetwork?.network != null } - var delayMs = 500L + var delayMs = 500L - while (isActive) { - val snapshot = stableNetworkEngine.stableState.value?.state - val network = - snapshot?.activeNetwork?.network - ?: run { - delay(100.milliseconds) - continue - } - - val dnsMode = forceDnsMode ?: getDnsMode() - val bypassNeeded = mode is BackendMode.Vpn || isKillSwitchEnabled() - var progressed = false - - for (peer in peersToResolve) { - if (results.containsKey(peer.publicKey)) continue - val host = peer.endpoint?.substringBeforeLast(":") ?: continue - - val dnsResult = - when (dnsMode) { - is DnsBoostrapMode.Custom -> { - resolveWithCustomConfig(dnsMode.config, host, bypassNeeded) + while (isActive) { + val snapshot = stableNetworkEngine.stableState.value?.state + val network = + snapshot?.activeNetwork?.network + ?: run { + delay(100.milliseconds) + continue } - is DnsBoostrapMode.System -> { - resolveWithSystemStrategy(snapshot, network, host, bypassNeeded) + + val dnsMode = getDnsMode() + val bypassNeeded = mode is BackendMode.Vpn || isKillSwitchEnabled() + var progressed = false + + for (peer in peersToResolve) { + if (results.containsKey(peer.publicKey)) continue + val host = peer.endpoint?.substringBeforeLast(":") ?: continue + + val resolver: PeerResolver = + when (dnsMode) { + is DnsBoostrapMode.System -> AndroidNetworkResolver(network) + is DnsBoostrapMode.Custom -> + CustomDnsResolver(dnsMode.config, bypassNeeded, network) } + + val result = resolver.resolve(host) + + if (result.ipv4.isNotEmpty() || result.ipv6.isNotEmpty()) { + results[peer.publicKey] = result.copy(ipv6 = result.ipv6.map { "[$it]" }) + progressed = true } - - if ( - dnsResult != null && - (dnsResult.ipv4.isNotEmpty() || dnsResult.ipv6.isNotEmpty()) - ) { - results[peer.publicKey] = dnsResult.copy(ipv6 = dnsResult.ipv6.map { "[$it]" }) - progressed = true - } - } - - if (results.keys.containsAll(peersToResolve.map { it.publicKey })) { - Timber.d("All peers resolved") - return@coroutineScope results - } - - if (!progressed) { - delay(delayMs.milliseconds) - delayMs = (delayMs * 2).coerceAtMost(MAX_BACKOFF) - } else { - delayMs = 500L // reset after we have progressed - } - } - return@coroutineScope results - } - - private suspend fun resolveWithSystemStrategy( - snapshot: ConnectivityState, - network: Network, - host: String, - bypass: Boolean, - ): DnsBootstrapResult? { - val dnsInfo = snapshot.underlyingDnsInfo - val hasDnsServers = dnsInfo.servers.isNotEmpty() - val hasPrivateDnsHostname = - dnsInfo.privateDnsMode == PrivateDnsMode.HOSTNAME && - !dnsInfo.privateDnsHostname.isNullOrBlank() - - return when { - // Private DNS hostname, use DoT/DoH via custom resolver - hasPrivateDnsHostname -> { - val hostname = dnsInfo.privateDnsHostname!! - val config = - DnsBoostrapConfig.SPECIAL_ANDROID_DOH_SERVERS[hostname]?.let { - DnsBoostrapConfig.DoH(it) - } ?: DnsBoostrapConfig.DoT(hostname) - - Timber.d("System and Private DNS, using ${config.protocol} for $host") - resolveWithCustomConfig(config, host, bypass) - } - - // Normal system DNS - hasDnsServers -> { - try { - Timber.d("Using system DNS with network provided DNS servers") - AndroidNetworkResolver(network).resolve(host) - } catch (e: Exception) { - Timber.w(e, "AndroidNetworkResolver failed for $host") - null - } - } - - // No DNS servers on network, fall back to custom with well known - else -> { - Timber.d("No DNS servers on network, falling back to public DNS for $host") - val publicConfig = DnsBoostrapConfig.Plain(DnsBoostrapConfig.DEFAULT_PLAIN_UPSTREAM) - resolveWithCustomConfig(publicConfig, host, bypass) - } - } - } - - private suspend fun resolveWithCustomConfig( - config: DnsBoostrapConfig, - host: String, - bypass: Boolean, - ): DnsBootstrapResult? { - val upstream = - config.upstream - ?: when (config) { - is DnsBoostrapConfig.DoH -> DnsBoostrapConfig.DEFAULT_DOH_UPSTREAM - is DnsBoostrapConfig.DoT -> DnsBoostrapConfig.DEFAULT_DOT_UPSTREAM - is DnsBoostrapConfig.Plain -> DnsBoostrapConfig.DEFAULT_PLAIN_UPSTREAM } - return try { - CustomDnsResolver(config, bypass).resolve(host) - } catch (e: Exception) { - Timber.w( - e, - "DNS resolution failed for host=%s protocol=%s upstream=%s bypass=%s", - host, - config.protocol, - upstream, - bypass, - ) - null + if (results.keys.containsAll(peersToResolve.map { it.publicKey })) { + Timber.d("All peers resolved") + return@coroutineScope results + } + + if (!progressed) { + delay(delayMs.milliseconds) + delayMs = (delayMs * 2).coerceAtMost(MAX_BACKOFF) + } else { + delayMs = 500L // reset after we have progressed + } + } + return@coroutineScope results } - } companion object { private const val MAX_BACKOFF = 30_000L diff --git a/tunnel/src/main/java/com/zaneschepke/tunnel/DnsConfigManager.kt b/tunnel/src/main/java/com/zaneschepke/tunnel/backend/dns/NativeDnsResolver.kt similarity index 69% rename from tunnel/src/main/java/com/zaneschepke/tunnel/DnsConfigManager.kt rename to tunnel/src/main/java/com/zaneschepke/tunnel/backend/dns/NativeDnsResolver.kt index e282d35f..fe2c41a7 100644 --- a/tunnel/src/main/java/com/zaneschepke/tunnel/DnsConfigManager.kt +++ b/tunnel/src/main/java/com/zaneschepke/tunnel/backend/dns/NativeDnsResolver.kt @@ -1,30 +1,36 @@ -package com.zaneschepke.tunnel +package com.zaneschepke.tunnel.backend.dns -import com.zaneschepke.tunnel.model.DnsBoostrapConfig import com.zaneschepke.tunnel.model.DnsBootstrapResult import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext -internal object DnsConfigManager { +internal object NativeDnsResolver { private external fun resolveBootstrap( host: String, protocol: String, - upstream: String, - underlyingDnsServers: String, + resolvedUpstream: String, + originalUpstream: String, bypass: Int, ): String suspend fun resolveHostBootstrap( host: String, protocol: String, - upstream: String, - underlyingDnsServers: String = DnsBoostrapConfig.DEFAULT_UNDERLYING_SERVERS, + resolvedUpstream: String, + originalUpstream: String, bypass: Boolean, ): DnsBootstrapResult = withContext(Dispatchers.IO) { val bypassOption = if (bypass) 1 else 0 - val raw = resolveBootstrap(host, protocol, upstream, underlyingDnsServers, bypassOption) + val raw = + resolveBootstrap( + host = host, + protocol = protocol, + resolvedUpstream = resolvedUpstream, + originalUpstream = originalUpstream, + bypass = bypassOption, + ) if (raw.startsWith("ERR|")) { throw RuntimeException(raw.removePrefix("ERR|")) diff --git a/tunnel/src/main/java/com/zaneschepke/tunnel/model/DnsBootstrap.kt b/tunnel/src/main/java/com/zaneschepke/tunnel/model/DnsBootstrap.kt index 7db95032..2a9750f6 100644 --- a/tunnel/src/main/java/com/zaneschepke/tunnel/model/DnsBootstrap.kt +++ b/tunnel/src/main/java/com/zaneschepke/tunnel/model/DnsBootstrap.kt @@ -24,18 +24,6 @@ sealed class DnsBoostrapConfig(open val upstream: String?) { override val protocol: String get() = "dot" } - - companion object { - const val DEFAULT_UNDERLYING_SERVERS = "1.1.1.1,8.8.8.8" - const val DEFAULT_PLAIN_UPSTREAM = "1.1.1.1" - const val DEFAULT_DOH_UPSTREAM = "https://cloudflare-dns.com/dns-query" - const val DEFAULT_DOT_UPSTREAM = "one.one.one.one" - val SPECIAL_ANDROID_DOH_SERVERS = - mapOf( - "cloudflare-dns.com" to "https://cloudflare-dns.com/dns-query", - "dns.google" to "https://dns.google/dns-query", - ) - } } data class DnsBootstrapResult( diff --git a/tunnel/src/main/java/com/zaneschepke/tunnel/util/DnsHostUtils.kt b/tunnel/src/main/java/com/zaneschepke/tunnel/util/DnsHostUtils.kt new file mode 100644 index 00000000..dca7c03e --- /dev/null +++ b/tunnel/src/main/java/com/zaneschepke/tunnel/util/DnsHostUtils.kt @@ -0,0 +1,79 @@ +package com.zaneschepke.tunnel.util + +import inet.ipaddr.IPAddressString +import java.net.URI + +object DnsHostUtils { + + /** Extracts the host portion from a DoH/DoT/Plain upstream string. */ + fun extractHost(upstream: String): String { + val trimmed = upstream.trim() + + // DoH full url + if (trimmed.startsWith("http://") || trimmed.startsWith("https://")) { + return try { + URI(trimmed).host ?: trimmed + } catch (_: Exception) { + trimmed + } + } + + val hostPart = trimmed.substringBeforeLast(":") + return hostPart.removeSurrounding("[", "]") + } + + /** Replaces the hostname in the upstream string with the given IP address. */ + fun replaceHostWithIP(upstream: String, newIp: String): String { + val trimmed = upstream.trim() + + val cleanedIp = newIp.trim().removeSurrounding("[", "]") + val isIpv6 = isIpAddress(cleanedIp) && cleanedIp.contains(":") + + val replacementIp = if (isIpv6) "[$cleanedIp]" else cleanedIp + + // handle full url for DoH + if (trimmed.startsWith("http://") || trimmed.startsWith("https://")) { + return try { + val uri = URI(trimmed) + val newAuthority = + if (uri.port != -1) { + "$replacementIp:${uri.port}" + } else { + replacementIp + } + + URI(uri.scheme, newAuthority, uri.path, uri.query, uri.fragment).toString() + } catch (_: Exception) { + // ust return the IP if URL parsing fails + replacementIp + } + } + + // host:port format DoT and plain + if (trimmed.contains(":")) { + val port = trimmed.substringAfterLast(":") + // Only treat as port if it's numeric + if (port.toIntOrNull() != null) { + return "$replacementIp:$port" + } + } + + // bare hostname/ip + return replacementIp + } + + fun isIpAddress(host: String): Boolean { + val cleaned = host.trim().removeSurrounding("[", "]") + return try { + val addr = IPAddressString(cleaned).address + addr != null && (addr.isIPv4 || addr.isIPv6) + } catch (_: Exception) { + false + } + } + + fun needsResolution(upstream: String): Boolean { + val host = extractHost(upstream) + return host.isNotBlank() && !isIpAddress(host) + } +} diff --git a/tunnel/tools/libwg-go/dns/dns.go b/tunnel/tools/libwg-go/dns/dns.go index bb56ae28..967e3495 100644 --- a/tunnel/tools/libwg-go/dns/dns.go +++ b/tunnel/tools/libwg-go/dns/dns.go @@ -43,47 +43,46 @@ type Transport interface { func ResolveBootstrap( host *C.char, protocol *C.char, - upstream *C.char, - underlyingDnsServers *C.char, + resolvedUpstream *C.char, + originalUpstream *C.char, bypass C.int, ) *C.char { + h := C.GoString(host) p := C.GoString(protocol) - u := C.GoString(upstream) - underlying := C.GoString(underlyingDnsServers) + resolved := C.GoString(resolvedUpstream) + original := C.GoString(originalUpstream) bp := bypass == 1 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - shared.LogDebug( - "DNS", - "ResolveBootstrap called host=%s protocol=%s upstream=%s bypass=%t", - h, p, u, bp, - ) + shared.LogDebug("DNS", "ResolveBootstrap called host=%s protocol=%s resolved=%s original=%s bypass=%t", + h, p, resolved, original, bp) - v4, v6, err := Resolve(ctx, h, p, u, bp, underlying) + v4, v6, err := Resolve(ctx, h, p, resolved, original, bp) if err != nil { shared.LogError("DNS", "ResolveBootstrap failed for %s: %v", h, err) return C.CString("ERR|" + err.Error()) } - v4Str := make([]string, len(v4)) - for i, ip := range v4 { - v4Str[i] = ip.String() - } - v6Str := make([]string, len(v6)) - for i, ip := range v6 { - v6Str[i] = ip.String() - } - - result := "v4=" + strings.Join(v4Str, ",") + - ";v6=" + strings.Join(v6Str, ",") + result := fmt.Sprintf("v4=%s;v6=%s", + strings.Join(toStringSlice(v4), ","), + strings.Join(toStringSlice(v6), ","), + ) shared.LogDebug("DNS", "ResolveBootstrap success for %s: %s", h, result) return C.CString(result) } +func toStringSlice(addrs []netip.Addr) []string { + out := make([]string, len(addrs)) + for i, a := range addrs { + out[i] = a.String() + } + return out +} + type DoTTransport struct { Client *dns.Client Servers []string @@ -264,20 +263,26 @@ func resolveServerAddrs( func (t PlainTransport) Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { for _, server := range t.Servers { - m, _, err := t.Client.Exchange(msg, server) + m, _, err := t.Client.ExchangeContext(ctx, msg, server) if err == nil && m != nil && m.Rcode == dns.RcodeSuccess { return m, nil } + if err != nil { + shared.LogDebug("DNS", "Plain DNS query to %s failed: %v", server, err) + } } return nil, fmt.Errorf("all DNS servers failed") } func (t DoTTransport) Query(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { for _, server := range t.Servers { - m, _, err := t.Client.Exchange(msg, server) + m, _, err := t.Client.ExchangeContext(ctx, msg, server) if err == nil && m != nil && m.Rcode == dns.RcodeSuccess { return m, nil } + if err != nil { + shared.LogDebug("DNS", "DoT Exchange to %s failed: %v", server, err) + } } return nil, fmt.Errorf("all DoT servers failed") } @@ -343,11 +348,11 @@ func parseDNSAnswers(msg *dns.Msg, qtype uint16) []netip.Addr { func Resolve( ctx context.Context, - host, protocol, upstream string, + host, protocol, resolvedUpstream, originalUpstream string, bypass bool, - underlying string, ) ([]netip.Addr, []netip.Addr, error) { - t, err := buildTransport(ctx, protocol, upstream, bypass, underlying) + + t, err := buildTransport(protocol, resolvedUpstream, originalUpstream, bypass) if err != nil { return nil, nil, err } @@ -355,88 +360,89 @@ func Resolve( } func buildTransport( - ctx context.Context, - protocol, upstream string, + protocol, resolvedUpstream, originalUpstream string, bypass bool, - underlying string, ) (Transport, error) { + switch protocol { case "doh": - u, err := url.Parse(upstream) + // Parse original for SNI + origURL, err := url.Parse(originalUpstream) if err != nil { - return nil, err + return nil, fmt.Errorf("invalid original DoH upstream: %w", err) } - hostname := u.Hostname() - port := u.Port() + + originalHost := origURL.Hostname() + + // Parse resolved to get the IP + resolvedURL, _ := url.Parse(resolvedUpstream) + dialHost := resolvedURL.Hostname() + if dialHost == "" { + dialHost = originalHost // fallback + } + + port := origURL.Port() if port == "" { port = "443" } - u.Host = net.JoinHostPort(hostname, port) - // Pre-resolve with IPv4-first ordering + bypass - servers, _, err := resolveServerAddrs(ctx, u.Host, bypass, "443", underlying) - if err != nil { - return nil, err - } - if len(servers) == 0 { - return nil, fmt.Errorf("no addresses resolved for DoH server") - } - - // Custom dialer that tries servers in order - // tries ipv4 first and then ipv6 dialer := GetDialer(bypass) + transport := &http.Transport{ - DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) { - for _, addr := range servers { - conn, err := dialer.DialContext(ctx, network, addr) - if err == nil { - return conn, nil - } - } - return nil, fmt.Errorf("all DoH addresses failed") + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, network, net.JoinHostPort(dialHost, port)) }, TLSClientConfig: &tls.Config{ - ServerName: hostname, + ServerName: originalHost, // Use original hostname for certificate validation }, } + finalURL := origURL.String() + if !strings.HasPrefix(finalURL, "https://") { + finalURL = "https://" + finalURL + } + return DoHTransport{ Client: &http.Client{Timeout: 5 * time.Second, Transport: transport}, - URL: u.String(), - Servers: servers, - Hostname: hostname, + URL: finalURL, + Hostname: originalHost, }, nil case "dot": - servers, sni, err := resolveServerAddrs(ctx, upstream, bypass, "853", underlying) + // Get SNI from original + origHost, origPort, err := net.SplitHostPort(originalUpstream) if err != nil { - return nil, err + origHost = originalUpstream + origPort = "853" } - if len(servers) == 0 { - return nil, fmt.Errorf("no addresses resolved for DoT server") + + // Get connection target from resolved + resolvedHost, resolvedPort, _ := net.SplitHostPort(resolvedUpstream) + if resolvedHost == "" { + resolvedHost = resolvedUpstream + resolvedPort = origPort } client := &dns.Client{ Net: "tcp-tls", Dialer: GetDialer(bypass), - Timeout: 5 * time.Second, + Timeout: 6 * time.Second, TLSConfig: &tls.Config{ - ServerName: sni, + ServerName: origHost, + MinVersion: tls.VersionTLS12, }, } + return DoTTransport{ Client: client, - Servers: servers, + Servers: []string{net.JoinHostPort(resolvedHost, resolvedPort)}, }, nil - default: // plain DNS - _, addr, err := parseUpstream(upstream) - if err != nil { - return nil, err - } - servers, _, err := resolveServerAddrs(ctx, addr, bypass, "53", underlying) - if err != nil { - return nil, err + default: // plain + host, port, _ := net.SplitHostPort(resolvedUpstream) + if host == "" { + host = resolvedUpstream + port = "53" } client := &dns.Client{ @@ -446,7 +452,7 @@ func buildTransport( } return PlainTransport{ Client: client, - Servers: servers, + Servers: []string{net.JoinHostPort(host, port)}, }, nil } } diff --git a/tunnel/tools/libwg-go/dns/dns_jni.c b/tunnel/tools/libwg-go/dns/dns_jni.c index 2c53df55..052fbc6b 100644 --- a/tunnel/tools/libwg-go/dns/dns_jni.c +++ b/tunnel/tools/libwg-go/dns/dns_jni.c @@ -6,45 +6,45 @@ struct go_string { const char *str; long n; }; extern char* ResolveBootstrap( const char* host, const char* protocol, - const char* upstream, - const char* underlyingDnsServers, + const char* resolvedUpstream, + const char* originalUpstream, int bypass); JNIEXPORT jstring JNICALL -Java_com_zaneschepke_tunnel_DnsConfigManager_resolveBootstrap( +Java_com_zaneschepke_tunnel_backend_dns_NativeDnsResolver_resolveBootstrap( JNIEnv* env, jclass clazz, jstring host, jstring protocol, - jstring upstream, - jstring underlyingDnsServers, + jstring resolvedUpstream, + jstring originalUpstream, jint bypass) { - if (host == NULL || protocol == NULL || upstream == NULL || underlyingDnsServers == NULL) { + if (host == NULL || protocol == NULL || resolvedUpstream == NULL || originalUpstream == NULL) { return (*env)->NewStringUTF(env, "ERR|invalid arguments"); } - const char* chost = (*env)->GetStringUTFChars(env, host, NULL); - const char* cprotocol = (*env)->GetStringUTFChars(env, protocol, NULL); - const char* cupstream = (*env)->GetStringUTFChars(env, upstream, NULL); - const char* cunderlying = (*env)->GetStringUTFChars(env, underlyingDnsServers, NULL); + const char* chost = (*env)->GetStringUTFChars(env, host, NULL); + const char* cprotocol = (*env)->GetStringUTFChars(env, protocol, NULL); + const char* cresolvedUpstream = (*env)->GetStringUTFChars(env, resolvedUpstream, NULL); + const char* coriginalUpstream = (*env)->GetStringUTFChars(env, originalUpstream, NULL); - if (chost == NULL || cprotocol == NULL || cupstream == NULL || cunderlying == NULL) { + if (chost == NULL || cprotocol == NULL || cresolvedUpstream == NULL || coriginalUpstream == NULL) { return (*env)->NewStringUTF(env, "ERR|out of memory"); } char* resultC = ResolveBootstrap( chost, cprotocol, - cupstream, - cunderlying, + cresolvedUpstream, + coriginalUpstream, bypass ? 1 : 0 ); (*env)->ReleaseStringUTFChars(env, host, chost); (*env)->ReleaseStringUTFChars(env, protocol, cprotocol); - (*env)->ReleaseStringUTFChars(env, upstream, cupstream); - (*env)->ReleaseStringUTFChars(env, underlyingDnsServers, cunderlying); + (*env)->ReleaseStringUTFChars(env, resolvedUpstream, cresolvedUpstream); + (*env)->ReleaseStringUTFChars(env, originalUpstream, coriginalUpstream); if (resultC == NULL) { return (*env)->NewStringUTF(env, "ERR|null response");