Making the cached content model aware

This commit is contained in:
2026-06-03 15:30:02 -07:00
parent a0412541a8
commit 5d4184f214
+49 -28
View File
@@ -32,12 +32,20 @@ def main():
sys.exit(1) sys.exit(1)
client = genai.Client() client = genai.Client()
context_data = {"file_ids": [], "cache_id": None}
# Updated structure to track caches and their associated file states by model
context_data = {"file_ids": [], "caches": {}}
if args.context and os.path.exists(args.context): if args.context and os.path.exists(args.context):
try: try:
with open(args.context, "r") as f: with open(args.context, "r") as f:
context_data = json.load(f) loaded_data = json.load(f)
context_data["file_ids"] = loaded_data.get("file_ids", [])
context_data["caches"] = loaded_data.get("caches", {})
# Handle migration from old format
if "cache_id" in loaded_data and loaded_data["cache_id"]:
print("Warning: Old context format detected. Please use -d to destroy and start fresh, or unexpected cache behavior may occur.", file=sys.stderr)
except json.JSONDecodeError: except json.JSONDecodeError:
print(f"Warning: Could not parse {args.context}. Starting fresh.", file=sys.stderr) print(f"Warning: Could not parse {args.context}. Starting fresh.", file=sys.stderr)
@@ -46,12 +54,12 @@ def main():
# --------------------------------------------------------- # ---------------------------------------------------------
if args.destroy: if args.destroy:
print("Destroying server resources and local context...") print("Destroying server resources and local context...")
if context_data.get("cache_id"): for model_name, cache_info in context_data.get("caches", {}).items():
try: try:
client.caches.delete(name=context_data["cache_id"]) client.caches.delete(name=cache_info["cache_id"])
print(f"Deleted cache: {context_data['cache_id']}") print(f"Deleted cache for {model_name}: {cache_info['cache_id']}")
except Exception as e: except Exception as e:
print(f"Warning: Failed to delete cache. {e}", file=sys.stderr) print(f"Warning: Failed to delete cache for {model_name}. {e}", file=sys.stderr)
for file_id in context_data.get("file_ids", []): for file_id in context_data.get("file_ids", []):
try: try:
@@ -71,7 +79,6 @@ def main():
# --------------------------------------------------------- # ---------------------------------------------------------
# UPLOAD LOGIC # UPLOAD LOGIC
# --------------------------------------------------------- # ---------------------------------------------------------
new_files_added = False
if args.files: if args.files:
for file_path in args.files: for file_path in args.files:
if not os.path.exists(file_path): if not os.path.exists(file_path):
@@ -84,7 +91,6 @@ def main():
if uploaded_file.name not in context_data["file_ids"]: if uploaded_file.name not in context_data["file_ids"]:
context_data["file_ids"].append(uploaded_file.name) context_data["file_ids"].append(uploaded_file.name)
new_files_added = True
if args.context: if args.context:
with open(args.context, "w") as f: with open(args.context, "w") as f:
@@ -100,23 +106,31 @@ def main():
"and output raw data only. If not specific data format is suggested, you can answer with conversational text." "and output raw data only. If not specific data format is suggested, you can answer with conversational text."
) )
# If new files were added but a cache already exists, we must destroy the stale cache.
if new_files_added and context_data.get("cache_id"):
print("New files detected. Destroying stale cache to rebuild...")
try:
client.caches.delete(name=context_data["cache_id"])
except Exception as e:
print(f"Warning: Could not delete stale cache. {e}", file=sys.stderr)
context_data["cache_id"] = None
cache_too_small = False cache_too_small = False
file_objects = [] file_objects = []
active_cache_id = None
if context_data.get("file_ids"): if context_data.get("file_ids"):
file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]] file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]]
model_cache_info = context_data["caches"].get(args.model)
current_files_set = set(context_data["file_ids"])
rebuild_cache = False
if model_cache_info:
cached_files_set = set(model_cache_info.get("cached_file_ids", []))
if current_files_set != cached_files_set:
print(f"File list changed. Destroying stale {args.model} cache to rebuild...")
try:
client.caches.delete(name=model_cache_info["cache_id"])
except Exception as e:
print(f"Warning: Could not delete stale cache. {e}", file=sys.stderr)
rebuild_cache = True
else:
rebuild_cache = True
if not context_data.get("cache_id"): if rebuild_cache:
print("Attempting to create Context Cache on Google's servers...") print(f"Attempting to create Context Cache for {args.model}...")
try: try:
cache = client.caches.create( cache = client.caches.create(
model=args.model, model=args.model,
@@ -126,7 +140,13 @@ def main():
ttl="3600s" ttl="3600s"
) )
) )
context_data["cache_id"] = cache.name
context_data["caches"][args.model] = {
"cache_id": cache.name,
"cached_file_ids": list(context_data["file_ids"])
}
active_cache_id = cache.name
if args.context: if args.context:
with open(args.context, "w") as f: with open(args.context, "w") as f:
json.dump(context_data, f, indent=4) json.dump(context_data, f, indent=4)
@@ -139,12 +159,13 @@ def main():
else: else:
raise e raise e
elif not cache_too_small: elif not cache_too_small and model_cache_info:
print(f"Loading existing cache: {context_data['cache_id']}") active_cache_id = model_cache_info["cache_id"]
print(f"Loading existing cache for {args.model}: {active_cache_id}")
print("Extending cache TTL by 60 minutes...") print("Extending cache TTL by 60 minutes...")
try: try:
client.caches.update( client.caches.update(
name=context_data["cache_id"], name=active_cache_id,
config=types.UpdateCachedContentConfig(ttl="3600s") config=types.UpdateCachedContentConfig(ttl="3600s")
) )
except Exception as e: except Exception as e:
@@ -161,8 +182,8 @@ def main():
generation_contents = [] generation_contents = []
if context_data.get("cache_id") and not cache_too_small: if active_cache_id and not cache_too_small:
config_kwargs["cached_content"] = context_data["cache_id"] config_kwargs["cached_content"] = active_cache_id
else: else:
generation_contents.extend(file_objects) generation_contents.extend(file_objects)
config_kwargs["system_instruction"] = system_instruction config_kwargs["system_instruction"] = system_instruction
@@ -198,10 +219,10 @@ def main():
# --------------------------------------------------------- # ---------------------------------------------------------
if not args.context and not args.destroy: if not args.context and not args.destroy:
print("\n[Transient Mode] Cleaning up resources...") print("\n[Transient Mode] Cleaning up resources...")
if context_data.get("cache_id"): for model_name, cache_info in context_data.get("caches", {}).items():
try: try:
client.caches.delete(name=context_data["cache_id"]) client.caches.delete(name=cache_info["cache_id"])
print(f"Deleted cache: {context_data['cache_id']}") print(f"Deleted cache for {model_name}: {cache_info['cache_id']}")
except Exception as e: except Exception as e:
print(f"Warning: Failed to delete cache. {e}", file=sys.stderr) print(f"Warning: Failed to delete cache. {e}", file=sys.stderr)