diff --git a/gemini/gemini.py b/gemini/gemini.py index 807ed84..464b9e8 100644 --- a/gemini/gemini.py +++ b/gemini/gemini.py @@ -32,12 +32,20 @@ def main(): sys.exit(1) client = genai.Client() - context_data = {"file_ids": [], "cache_id": None} + + # Updated structure to track caches and their associated file states by model + context_data = {"file_ids": [], "caches": {}} if args.context and os.path.exists(args.context): try: with open(args.context, "r") as f: - context_data = json.load(f) + loaded_data = json.load(f) + context_data["file_ids"] = loaded_data.get("file_ids", []) + context_data["caches"] = loaded_data.get("caches", {}) + + # Handle migration from old format + if "cache_id" in loaded_data and loaded_data["cache_id"]: + print("Warning: Old context format detected. Please use -d to destroy and start fresh, or unexpected cache behavior may occur.", file=sys.stderr) except json.JSONDecodeError: print(f"Warning: Could not parse {args.context}. Starting fresh.", file=sys.stderr) @@ -46,12 +54,12 @@ def main(): # --------------------------------------------------------- if args.destroy: print("Destroying server resources and local context...") - if context_data.get("cache_id"): + for model_name, cache_info in context_data.get("caches", {}).items(): try: - client.caches.delete(name=context_data["cache_id"]) - print(f"Deleted cache: {context_data['cache_id']}") + client.caches.delete(name=cache_info["cache_id"]) + print(f"Deleted cache for {model_name}: {cache_info['cache_id']}") except Exception as e: - print(f"Warning: Failed to delete cache. {e}", file=sys.stderr) + print(f"Warning: Failed to delete cache for {model_name}. {e}", file=sys.stderr) for file_id in context_data.get("file_ids", []): try: @@ -71,7 +79,6 @@ def main(): # --------------------------------------------------------- # UPLOAD LOGIC # --------------------------------------------------------- - new_files_added = False if args.files: for file_path in args.files: if not os.path.exists(file_path): @@ -84,7 +91,6 @@ def main(): if uploaded_file.name not in context_data["file_ids"]: context_data["file_ids"].append(uploaded_file.name) - new_files_added = True if args.context: with open(args.context, "w") as f: @@ -100,23 +106,31 @@ def main(): "and output raw data only. If not specific data format is suggested, you can answer with conversational text." ) - # If new files were added but a cache already exists, we must destroy the stale cache. - if new_files_added and context_data.get("cache_id"): - print("New files detected. Destroying stale cache to rebuild...") - try: - client.caches.delete(name=context_data["cache_id"]) - except Exception as e: - print(f"Warning: Could not delete stale cache. {e}", file=sys.stderr) - context_data["cache_id"] = None - cache_too_small = False file_objects = [] + active_cache_id = None if context_data.get("file_ids"): file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]] + + model_cache_info = context_data["caches"].get(args.model) + current_files_set = set(context_data["file_ids"]) + rebuild_cache = False + + if model_cache_info: + cached_files_set = set(model_cache_info.get("cached_file_ids", [])) + if current_files_set != cached_files_set: + print(f"File list changed. Destroying stale {args.model} cache to rebuild...") + try: + client.caches.delete(name=model_cache_info["cache_id"]) + except Exception as e: + print(f"Warning: Could not delete stale cache. {e}", file=sys.stderr) + rebuild_cache = True + else: + rebuild_cache = True - if not context_data.get("cache_id"): - print("Attempting to create Context Cache on Google's servers...") + if rebuild_cache: + print(f"Attempting to create Context Cache for {args.model}...") try: cache = client.caches.create( model=args.model, @@ -126,7 +140,13 @@ def main(): ttl="3600s" ) ) - context_data["cache_id"] = cache.name + + context_data["caches"][args.model] = { + "cache_id": cache.name, + "cached_file_ids": list(context_data["file_ids"]) + } + active_cache_id = cache.name + if args.context: with open(args.context, "w") as f: json.dump(context_data, f, indent=4) @@ -139,12 +159,13 @@ def main(): else: raise e - elif not cache_too_small: - print(f"Loading existing cache: {context_data['cache_id']}") + elif not cache_too_small and model_cache_info: + active_cache_id = model_cache_info["cache_id"] + print(f"Loading existing cache for {args.model}: {active_cache_id}") print("Extending cache TTL by 60 minutes...") try: client.caches.update( - name=context_data["cache_id"], + name=active_cache_id, config=types.UpdateCachedContentConfig(ttl="3600s") ) except Exception as e: @@ -161,8 +182,8 @@ def main(): generation_contents = [] - if context_data.get("cache_id") and not cache_too_small: - config_kwargs["cached_content"] = context_data["cache_id"] + if active_cache_id and not cache_too_small: + config_kwargs["cached_content"] = active_cache_id else: generation_contents.extend(file_objects) config_kwargs["system_instruction"] = system_instruction @@ -198,10 +219,10 @@ def main(): # --------------------------------------------------------- if not args.context and not args.destroy: print("\n[Transient Mode] Cleaning up resources...") - if context_data.get("cache_id"): + for model_name, cache_info in context_data.get("caches", {}).items(): try: - client.caches.delete(name=context_data["cache_id"]) - print(f"Deleted cache: {context_data['cache_id']}") + client.caches.delete(name=cache_info["cache_id"]) + print(f"Deleted cache for {model_name}: {cache_info['cache_id']}") except Exception as e: print(f"Warning: Failed to delete cache. {e}", file=sys.stderr)