Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions src/google/adk/utils/instructions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from ..sessions.state import State

__all__ = [
'inject_session_state',
"inject_session_state",
]

logger = logging.getLogger('google_adk.' + __name__)
logger = logging.getLogger("google_adk." + __name__)


async def inject_session_state(
Expand Down Expand Up @@ -76,18 +76,23 @@ async def _async_sub(pattern, repl_async_fn, string) -> str:
result.append(replacement)
last_end = match.end()
result.append(string[last_end:])
return ''.join(result)
return "".join(result)

async def _replace_match(match) -> str:
var_name = match.group().lstrip('{').rstrip('}').strip()
matched_text = match.group()

if matched_text.startswith("{{") and matched_text.endswith("}}"):
return matched_text[1:-1]
Comment on lines +84 to +85
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This check correctly handles the documented {{...}} escaping. However, it also matches patterns with more than two braces, like {{{variable}}}. This changes the behavior for such patterns. Previously, {{{variable}}} would be processed by lstrip/rstrip, resulting in the substitution of variable. With this change, it will be unescaped to {{variable}} and rendered literally.

While the new behavior is arguably more predictable, it's an undocumented side effect of this fix. If this change is intentional, consider adding a test case for it. If not, you might want to make this check more specific to exactly two braces to avoid altering the behavior for other patterns.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you please address this.


var_name = matched_text.lstrip("{").rstrip("}").strip()
optional = False
if var_name.endswith('?'):
if var_name.endswith("?"):
optional = True
var_name = var_name.removesuffix('?')
if var_name.startswith('artifact.'):
var_name = var_name.removeprefix('artifact.')
var_name = var_name.removesuffix("?")
if var_name.startswith("artifact."):
var_name = var_name.removeprefix("artifact.")
if invocation_context.artifact_service is None:
raise ValueError('Artifact service is not initialized.')
raise ValueError("Artifact service is not initialized.")
artifact = await invocation_context.artifact_service.load_artifact(
app_name=invocation_context.session.app_name,
user_id=invocation_context.session.user_id,
Expand All @@ -97,31 +102,31 @@ async def _replace_match(match) -> str:
if artifact is None:
if optional:
logger.debug(
'Artifact %s not found, replacing with empty string', var_name
"Artifact %s not found, replacing with empty string", var_name
)
return ''
return ""
else:
raise KeyError(f'Artifact {var_name} not found.')
raise KeyError(f"Artifact {var_name} not found.")
return str(artifact)
else:
if not _is_valid_state_name(var_name):
return match.group()
if var_name in invocation_context.session.state:
value = invocation_context.session.state[var_name]
if value is None:
return ''
return ""
return str(value)
else:
if optional:
logger.debug(
'Context variable %s not found, replacing with empty string',
"Context variable %s not found, replacing with empty string",
var_name,
)
return ''
return ""
else:
raise KeyError(f'Context variable not found: `{var_name}`.')
raise KeyError(f"Context variable not found: `{var_name}`.")

return await _async_sub(r'{+[^{}]*}+', _replace_match, template)
return await _async_sub(r"{+[^{}]*}+", _replace_match, template)


def _is_valid_state_name(var_name):
Expand All @@ -138,12 +143,12 @@ def _is_valid_state_name(var_name):
Returns:
True if the variable name is a valid state name, False otherwise.
"""
parts = var_name.split(':')
parts = var_name.split(":")
if len(parts) == 1:
return var_name.isidentifier()

if len(parts) == 2:
prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX]
if (parts[0] + ':') in prefixes:
if (parts[0] + ":") in prefixes:
return parts[1].isidentifier()
return False
71 changes: 71 additions & 0 deletions tests/unittests/utils/test_instructions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,74 @@ async def test_inject_session_state_with_optional_missing_state_returns_empty():
instruction_template, invocation_context
)
assert populated_instruction == "Optional value: "


@pytest.mark.asyncio
async def test_inject_session_state_with_double_brace_escaping():
instruction_template = "Example: {{user_id}}"
invocation_context = await _create_test_readonly_context()

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Example: {user_id}"


@pytest.mark.asyncio
async def test_inject_session_state_with_double_brace_escaping_and_normal_substitution():
instruction_template = "Hello {name}, example: {{variable}}"
invocation_context = await _create_test_readonly_context(
state={"name": "Alice"}
)

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Hello Alice, example: {variable}"


@pytest.mark.asyncio
async def test_inject_session_state_with_python_fstring_example():
instruction_template = """
Example Python code:
logger.error(f"User not found: {{user_id}}")
"""
invocation_context = await _create_test_readonly_context()

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
expected = """
Example Python code:
logger.error(f"User not found: {user_id}")
"""
assert populated_instruction == expected


@pytest.mark.asyncio
async def test_inject_session_state_with_typescript_template_literal():
instruction_template = """
Example TypeScript code:
console.log(`User: ${{userId}}`);
"""
invocation_context = await _create_test_readonly_context()

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
expected = """
Example TypeScript code:
console.log(`User: ${userId}`);
"""
assert populated_instruction == expected


@pytest.mark.asyncio
async def test_inject_session_state_with_multiple_double_brace_patterns():
instruction_template = "Examples: {{var1}}, {{var2}}, {{var3}}"
invocation_context = await _create_test_readonly_context()

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Examples: {var1}, {var2}, {var3}"
Loading