fix: correctly parse thinking content from gemma4 (#8407)

Signed-off-by: jh-block <jhugo@block.xyz>
This commit is contained in:
jh-block
2026-04-08 21:22:35 +02:00
committed by GitHub
parent a7881a2e6c
commit 8b1007eee0
6 changed files with 73 additions and 27 deletions
Generated
+4 -4
View File
@@ -5860,9 +5860,9 @@ checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092"
[[package]]
name = "llama-cpp-2"
version = "0.1.142"
version = "0.1.143"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f04fe7a4b6836bff1966eb8c35e5e121f573641887526d9f997e9c10a64d1bc"
checksum = "d564eb5d7ae88f596e7636ffd549e0f27f4c938a7c7841bb91e2a92c248c9ccb"
dependencies = [
"encoding_rs",
"enumflags2",
@@ -5874,9 +5874,9 @@ dependencies = [
[[package]]
name = "llama-cpp-sys-2"
version = "0.1.142"
version = "0.1.143"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95ef4708a34c686a2686f9a57050df3b5384ff346dcbfb8ee47b5814f4dcfc90"
checksum = "2f2cf3435dbadb87817e0a95325c818cd89d43026e8ba1ddd32f1d980d96f33d"
dependencies = [
"bindgen",
"cc",
+2 -2
View File
@@ -179,7 +179,7 @@ tree-sitter-typescript = { workspace = true }
which = { workspace = true }
pctx_code_mode = { version = "^0.3.0", optional = true }
pulldown-cmark = "0.13.0"
llama-cpp-2 = { version = "0.1.142", features = ["sampler"], optional = true }
llama-cpp-2 = { version = "0.1.143", features = ["sampler"], optional = true }
encoding_rs = "0.8.35"
pastey = "0.2.1"
shell-words = { workspace = true }
@@ -197,7 +197,7 @@ keyring = { version = "3.6.2", features = ["windows-native"] }
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { version = "0.9", default-features = false, features = ["metal"], optional = true }
candle-nn = { version = "0.9", default-features = false, features = ["metal"], optional = true }
llama-cpp-2 = { version = "0.1.142", features = ["sampler", "metal"], optional = true }
llama-cpp-2 = { version = "0.1.143", features = ["sampler", "metal"], optional = true }
keyring = { version = "3.6.2", features = ["apple-native"] }
[target.'cfg(target_os = "linux")'.dependencies]
+10 -2
View File
@@ -122,12 +122,20 @@ pub fn resolve_model_path(
usize,
crate::providers::local_inference::local_model_registry::ModelSettings,
)> {
use crate::providers::local_inference::local_model_registry::get_registry;
use crate::providers::local_inference::local_model_registry::{
default_settings_for_model, get_registry,
};
if let Ok(registry) = get_registry().lock() {
if let Some(entry) = registry.get_model(model_id) {
let ctx = entry.settings.context_size.unwrap_or(0) as usize;
return Some((entry.local_path.clone(), ctx, entry.settings.clone()));
let mut settings = entry.settings.clone();
// Capability flags are inherent to the model family, not user-configurable.
// Re-derive them so that registry entries persisted before a model was
// recognized (or with a different quantization) still get the right behavior.
let defaults = default_settings_for_model(model_id);
settings.native_tool_calling = defaults.native_tool_calling;
return Some((entry.local_path.clone(), ctx, settings));
}
}
@@ -118,15 +118,15 @@ pub(super) fn effective_context_size(
) -> usize {
let limit = context_cap(settings, context_limit, n_ctx_train, memory_max_ctx);
let min_generation_headroom = 512;
let needed = prompt_token_count + min_generation_headroom;
if needed > limit {
if prompt_token_count + min_generation_headroom > limit {
tracing::warn!(
"Prompt ({} tokens) + headroom exceeds context limit ({}), capping to limit",
"Prompt ({} tokens) + minimum headroom ({}) exceeds context limit ({})",
prompt_token_count,
min_generation_headroom,
limit,
);
}
needed.min(limit)
limit
}
pub(super) fn build_context_params(
@@ -269,7 +269,8 @@ pub(super) enum TokenAction {
/// Run the autoregressive generation loop. Calls `on_piece` for each non-empty
/// token piece. The callback returns `TokenAction::Stop` to break early.
/// Returns the total number of generated tokens.
/// Returns the total number of generated tokens, or `ContextLengthExceeded`
/// if the model exhausted the available context window.
pub(super) fn generation_loop(
model: &LlamaModel,
ctx: &mut llama_cpp_2::context::LlamaContext<'_>,
@@ -279,19 +280,25 @@ pub(super) fn generation_loop(
mut on_piece: impl FnMut(&str) -> Result<TokenAction, ProviderError>,
) -> Result<i32, ProviderError> {
let mut sampler = build_sampler(settings);
let context_headroom = effective_ctx.saturating_sub(prompt_token_count);
let max_output = if let Some(max) = settings.max_output_tokens {
effective_ctx.saturating_sub(prompt_token_count).min(max)
context_headroom.min(max)
} else {
effective_ctx.saturating_sub(prompt_token_count)
context_headroom
};
let hit_context_limit = settings
.max_output_tokens
.is_none_or(|max| context_headroom <= max);
let mut decoder = encoding_rs::UTF_8.new_decoder();
let mut output_token_count: i32 = 0;
let mut exhausted_loop = true;
for _ in 0..max_output {
let token = sampler.sample(ctx, -1);
sampler.accept(token);
if model.is_eog_token(token) {
exhausted_loop = false;
break;
}
@@ -302,6 +309,7 @@ pub(super) fn generation_loop(
.map_err(|e| ProviderError::ExecutionError(format!("Failed to decode token: {}", e)))?;
if !piece.is_empty() && matches!(on_piece(&piece)?, TokenAction::Stop) {
exhausted_loop = false;
break;
}
@@ -312,6 +320,16 @@ pub(super) fn generation_loop(
.map_err(|e| ProviderError::ExecutionError(format!("Decode failed: {}", e)))?;
}
if exhausted_loop && hit_context_limit {
return Err(ProviderError::ContextLengthExceeded(format!(
"Generation exhausted context window ({} prompt + {} generated = {} of {} limit)",
prompt_token_count,
output_token_count,
prompt_token_count as i32 + output_token_count,
effective_ctx,
)));
}
Ok(output_token_count)
}
@@ -325,10 +343,10 @@ mod tests {
}
#[test]
fn test_effective_context_size_basic() {
fn test_effective_context_size_uses_full_limit() {
assert_eq!(
effective_context_size(100, &default_settings(), 4096, 4096, None),
612
4096
);
}
@@ -336,7 +354,7 @@ mod tests {
fn test_effective_context_size_capped_by_limit() {
assert_eq!(
effective_context_size(100, &default_settings(), 1024, 8192, None),
612
1024
);
}
@@ -344,12 +362,12 @@ mod tests {
fn test_effective_context_size_capped_by_memory() {
assert_eq!(
effective_context_size(100, &default_settings(), 4096, 4096, Some(800)),
612
800
);
}
#[test]
fn test_effective_context_size_memory_smaller_than_needed() {
fn test_effective_context_size_memory_smaller_than_prompt() {
assert_eq!(
effective_context_size(600, &default_settings(), 4096, 4096, Some(700)),
700
@@ -360,7 +378,7 @@ mod tests {
fn test_effective_context_size_zero_limit_uses_train() {
assert_eq!(
effective_context_size(100, &default_settings(), 0, 2048, None),
612
2048
);
}
@@ -33,12 +33,12 @@ pub(super) fn generate_with_native_tools(
tool_choice: None,
json_schema: None,
grammar: None,
reasoning_format: None,
reasoning_format: Some("auto"),
chat_template_kwargs: None,
add_generation_prompt: true,
use_jinja: true,
parallel_tool_calls: false,
enable_thinking: false,
enable_thinking: true,
add_bos: false,
add_eos: false,
parse_tool_calls: true,
@@ -134,6 +134,18 @@ pub(super) fn generate_with_native_tools(
Ok(deltas) => {
for delta_json in deltas {
if let Ok(delta) = serde_json::from_str::<Value>(&delta_json) {
// Stream thinking/reasoning content
if let Some(reasoning) =
delta.get("reasoning_content").and_then(|v| v.as_str())
{
if !reasoning.is_empty() {
let mut msg = Message::assistant().with_thinking(reasoning, "");
msg.id = Some(message_id.to_string());
if tx.blocking_send(Ok((Some(msg), None))).is_err() {
return Ok(TokenAction::Stop);
}
}
}
// Stream content text to the UI
if let Some(content) = delta.get("content").and_then(|v| v.as_str()) {
if !content.is_empty() {
@@ -181,6 +193,13 @@ pub(super) fn generate_with_native_tools(
if let Ok(final_deltas) = stream_parser.update("", false) {
for delta_json in final_deltas {
if let Ok(delta) = serde_json::from_str::<Value>(&delta_json) {
if let Some(reasoning) = delta.get("reasoning_content").and_then(|v| v.as_str()) {
if !reasoning.is_empty() {
let mut msg = Message::assistant().with_thinking(reasoning, "");
msg.id = Some(message_id.to_string());
let _ = tx.blocking_send(Ok((Some(msg), None)));
}
}
if let Some(content) = delta.get("content").and_then(|v| v.as_str()) {
if !content.is_empty() {
let mut msg = Message::assistant().with_text(content);
@@ -127,15 +127,16 @@ pub const FEATURED_MODELS: &[FeaturedModel] = &[
pub fn default_settings_for_model(model_id: &str) -> ModelSettings {
use super::hf_models::parse_model_spec;
let native = FEATURED_MODELS.iter().any(|m| {
if let Ok((repo_id, quant)) = parse_model_spec(m.spec) {
model_id_from_repo(&repo_id, &quant) == model_id && m.native_tool_calling
let model_repo = model_id.split(':').next().unwrap_or(model_id);
let featured = FEATURED_MODELS.iter().find(|m| {
if let Ok((repo_id, _quant)) = parse_model_spec(m.spec) {
repo_id == model_repo
} else {
false
}
});
ModelSettings {
native_tool_calling: native,
native_tool_calling: featured.is_some_and(|m| m.native_tool_calling),
..ModelSettings::default()
}
}