From 8b1007eee0058fe646ac37949aca72cf3d78d7ee Mon Sep 17 00:00:00 2001 From: jh-block Date: Wed, 8 Apr 2026 21:22:35 +0200 Subject: [PATCH] fix: correctly parse thinking content from gemma4 (#8407) Signed-off-by: jh-block --- Cargo.lock | 8 ++-- crates/goose/Cargo.toml | 4 +- crates/goose/src/providers/local_inference.rs | 12 ++++- .../local_inference/inference_engine.rs | 44 +++++++++++++------ .../local_inference/inference_native_tools.rs | 23 +++++++++- .../local_inference/local_model_registry.rs | 9 ++-- 6 files changed, 73 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9ef300971e..3a1eca7898 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5860,9 +5860,9 @@ checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" [[package]] name = "llama-cpp-2" -version = "0.1.142" +version = "0.1.143" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f04fe7a4b6836bff1966eb8c35e5e121f573641887526d9f997e9c10a64d1bc" +checksum = "d564eb5d7ae88f596e7636ffd549e0f27f4c938a7c7841bb91e2a92c248c9ccb" dependencies = [ "encoding_rs", "enumflags2", @@ -5874,9 +5874,9 @@ dependencies = [ [[package]] name = "llama-cpp-sys-2" -version = "0.1.142" +version = "0.1.143" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95ef4708a34c686a2686f9a57050df3b5384ff346dcbfb8ee47b5814f4dcfc90" +checksum = "2f2cf3435dbadb87817e0a95325c818cd89d43026e8ba1ddd32f1d980d96f33d" dependencies = [ "bindgen", "cc", diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 0a892e444f..f2ae36e110 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -179,7 +179,7 @@ tree-sitter-typescript = { workspace = true } which = { workspace = true } pctx_code_mode = { version = "^0.3.0", optional = true } pulldown-cmark = "0.13.0" -llama-cpp-2 = { version = "0.1.142", features = ["sampler"], optional = true } +llama-cpp-2 = { version = "0.1.143", features = ["sampler"], optional = true } encoding_rs = "0.8.35" pastey = "0.2.1" shell-words = { workspace = true } @@ -197,7 +197,7 @@ keyring = { version = "3.6.2", features = ["windows-native"] } [target.'cfg(target_os = "macos")'.dependencies] candle-core = { version = "0.9", default-features = false, features = ["metal"], optional = true } candle-nn = { version = "0.9", default-features = false, features = ["metal"], optional = true } -llama-cpp-2 = { version = "0.1.142", features = ["sampler", "metal"], optional = true } +llama-cpp-2 = { version = "0.1.143", features = ["sampler", "metal"], optional = true } keyring = { version = "3.6.2", features = ["apple-native"] } [target.'cfg(target_os = "linux")'.dependencies] diff --git a/crates/goose/src/providers/local_inference.rs b/crates/goose/src/providers/local_inference.rs index f6b6fb97a9..4ff331c5d4 100644 --- a/crates/goose/src/providers/local_inference.rs +++ b/crates/goose/src/providers/local_inference.rs @@ -122,12 +122,20 @@ pub fn resolve_model_path( usize, crate::providers::local_inference::local_model_registry::ModelSettings, )> { - use crate::providers::local_inference::local_model_registry::get_registry; + use crate::providers::local_inference::local_model_registry::{ + default_settings_for_model, get_registry, + }; if let Ok(registry) = get_registry().lock() { if let Some(entry) = registry.get_model(model_id) { let ctx = entry.settings.context_size.unwrap_or(0) as usize; - return Some((entry.local_path.clone(), ctx, entry.settings.clone())); + let mut settings = entry.settings.clone(); + // Capability flags are inherent to the model family, not user-configurable. + // Re-derive them so that registry entries persisted before a model was + // recognized (or with a different quantization) still get the right behavior. + let defaults = default_settings_for_model(model_id); + settings.native_tool_calling = defaults.native_tool_calling; + return Some((entry.local_path.clone(), ctx, settings)); } } diff --git a/crates/goose/src/providers/local_inference/inference_engine.rs b/crates/goose/src/providers/local_inference/inference_engine.rs index 90940f357e..23ec84f293 100644 --- a/crates/goose/src/providers/local_inference/inference_engine.rs +++ b/crates/goose/src/providers/local_inference/inference_engine.rs @@ -118,15 +118,15 @@ pub(super) fn effective_context_size( ) -> usize { let limit = context_cap(settings, context_limit, n_ctx_train, memory_max_ctx); let min_generation_headroom = 512; - let needed = prompt_token_count + min_generation_headroom; - if needed > limit { + if prompt_token_count + min_generation_headroom > limit { tracing::warn!( - "Prompt ({} tokens) + headroom exceeds context limit ({}), capping to limit", + "Prompt ({} tokens) + minimum headroom ({}) exceeds context limit ({})", prompt_token_count, + min_generation_headroom, limit, ); } - needed.min(limit) + limit } pub(super) fn build_context_params( @@ -269,7 +269,8 @@ pub(super) enum TokenAction { /// Run the autoregressive generation loop. Calls `on_piece` for each non-empty /// token piece. The callback returns `TokenAction::Stop` to break early. -/// Returns the total number of generated tokens. +/// Returns the total number of generated tokens, or `ContextLengthExceeded` +/// if the model exhausted the available context window. pub(super) fn generation_loop( model: &LlamaModel, ctx: &mut llama_cpp_2::context::LlamaContext<'_>, @@ -279,19 +280,25 @@ pub(super) fn generation_loop( mut on_piece: impl FnMut(&str) -> Result, ) -> Result { let mut sampler = build_sampler(settings); + let context_headroom = effective_ctx.saturating_sub(prompt_token_count); let max_output = if let Some(max) = settings.max_output_tokens { - effective_ctx.saturating_sub(prompt_token_count).min(max) + context_headroom.min(max) } else { - effective_ctx.saturating_sub(prompt_token_count) + context_headroom }; + let hit_context_limit = settings + .max_output_tokens + .is_none_or(|max| context_headroom <= max); let mut decoder = encoding_rs::UTF_8.new_decoder(); let mut output_token_count: i32 = 0; + let mut exhausted_loop = true; for _ in 0..max_output { let token = sampler.sample(ctx, -1); sampler.accept(token); if model.is_eog_token(token) { + exhausted_loop = false; break; } @@ -302,6 +309,7 @@ pub(super) fn generation_loop( .map_err(|e| ProviderError::ExecutionError(format!("Failed to decode token: {}", e)))?; if !piece.is_empty() && matches!(on_piece(&piece)?, TokenAction::Stop) { + exhausted_loop = false; break; } @@ -312,6 +320,16 @@ pub(super) fn generation_loop( .map_err(|e| ProviderError::ExecutionError(format!("Decode failed: {}", e)))?; } + if exhausted_loop && hit_context_limit { + return Err(ProviderError::ContextLengthExceeded(format!( + "Generation exhausted context window ({} prompt + {} generated = {} of {} limit)", + prompt_token_count, + output_token_count, + prompt_token_count as i32 + output_token_count, + effective_ctx, + ))); + } + Ok(output_token_count) } @@ -325,10 +343,10 @@ mod tests { } #[test] - fn test_effective_context_size_basic() { + fn test_effective_context_size_uses_full_limit() { assert_eq!( effective_context_size(100, &default_settings(), 4096, 4096, None), - 612 + 4096 ); } @@ -336,7 +354,7 @@ mod tests { fn test_effective_context_size_capped_by_limit() { assert_eq!( effective_context_size(100, &default_settings(), 1024, 8192, None), - 612 + 1024 ); } @@ -344,12 +362,12 @@ mod tests { fn test_effective_context_size_capped_by_memory() { assert_eq!( effective_context_size(100, &default_settings(), 4096, 4096, Some(800)), - 612 + 800 ); } #[test] - fn test_effective_context_size_memory_smaller_than_needed() { + fn test_effective_context_size_memory_smaller_than_prompt() { assert_eq!( effective_context_size(600, &default_settings(), 4096, 4096, Some(700)), 700 @@ -360,7 +378,7 @@ mod tests { fn test_effective_context_size_zero_limit_uses_train() { assert_eq!( effective_context_size(100, &default_settings(), 0, 2048, None), - 612 + 2048 ); } diff --git a/crates/goose/src/providers/local_inference/inference_native_tools.rs b/crates/goose/src/providers/local_inference/inference_native_tools.rs index f26847f424..3d12d37004 100644 --- a/crates/goose/src/providers/local_inference/inference_native_tools.rs +++ b/crates/goose/src/providers/local_inference/inference_native_tools.rs @@ -33,12 +33,12 @@ pub(super) fn generate_with_native_tools( tool_choice: None, json_schema: None, grammar: None, - reasoning_format: None, + reasoning_format: Some("auto"), chat_template_kwargs: None, add_generation_prompt: true, use_jinja: true, parallel_tool_calls: false, - enable_thinking: false, + enable_thinking: true, add_bos: false, add_eos: false, parse_tool_calls: true, @@ -134,6 +134,18 @@ pub(super) fn generate_with_native_tools( Ok(deltas) => { for delta_json in deltas { if let Ok(delta) = serde_json::from_str::(&delta_json) { + // Stream thinking/reasoning content + if let Some(reasoning) = + delta.get("reasoning_content").and_then(|v| v.as_str()) + { + if !reasoning.is_empty() { + let mut msg = Message::assistant().with_thinking(reasoning, ""); + msg.id = Some(message_id.to_string()); + if tx.blocking_send(Ok((Some(msg), None))).is_err() { + return Ok(TokenAction::Stop); + } + } + } // Stream content text to the UI if let Some(content) = delta.get("content").and_then(|v| v.as_str()) { if !content.is_empty() { @@ -181,6 +193,13 @@ pub(super) fn generate_with_native_tools( if let Ok(final_deltas) = stream_parser.update("", false) { for delta_json in final_deltas { if let Ok(delta) = serde_json::from_str::(&delta_json) { + if let Some(reasoning) = delta.get("reasoning_content").and_then(|v| v.as_str()) { + if !reasoning.is_empty() { + let mut msg = Message::assistant().with_thinking(reasoning, ""); + msg.id = Some(message_id.to_string()); + let _ = tx.blocking_send(Ok((Some(msg), None))); + } + } if let Some(content) = delta.get("content").and_then(|v| v.as_str()) { if !content.is_empty() { let mut msg = Message::assistant().with_text(content); diff --git a/crates/goose/src/providers/local_inference/local_model_registry.rs b/crates/goose/src/providers/local_inference/local_model_registry.rs index d096824c57..49a397057c 100644 --- a/crates/goose/src/providers/local_inference/local_model_registry.rs +++ b/crates/goose/src/providers/local_inference/local_model_registry.rs @@ -127,15 +127,16 @@ pub const FEATURED_MODELS: &[FeaturedModel] = &[ pub fn default_settings_for_model(model_id: &str) -> ModelSettings { use super::hf_models::parse_model_spec; - let native = FEATURED_MODELS.iter().any(|m| { - if let Ok((repo_id, quant)) = parse_model_spec(m.spec) { - model_id_from_repo(&repo_id, &quant) == model_id && m.native_tool_calling + let model_repo = model_id.split(':').next().unwrap_or(model_id); + let featured = FEATURED_MODELS.iter().find(|m| { + if let Ok((repo_id, _quant)) = parse_model_spec(m.spec) { + repo_id == model_repo } else { false } }); ModelSettings { - native_tool_calling: native, + native_tool_calling: featured.is_some_and(|m| m.native_tool_calling), ..ModelSettings::default() } }