diff --git a/gemini/gemini.py b/gemini/gemini.py index 464b9e8..ff21065 100644 --- a/gemini/gemini.py +++ b/gemini/gemini.py @@ -9,13 +9,15 @@ from google.genai import types def main(): parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching") parser.add_argument("-c", "--context", type=str, default=None, - help="Path to context file. If omitted, files/caches are deleted after execution.") + help="Path to context file. If omitted, transient mode is used.") parser.add_argument("-f", "--files", nargs="+", default=[], help="Files to upload to the Gemini API") parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite", help="The model to use (default: gemini-3.1-flash-lite)") parser.add_argument("-d", "--destroy", action="store_true", - help="Destroy the cloud files and cache, and delete the local context file") + help="Destroy cloud files/cache, and delete local context") + parser.add_argument("-x", "--clear-history", action="store_true", + help="Clear the conversation history without destroying files/caches") parser.add_argument("-o", "--output", type=str, help="Direct the raw output to a specific file instead of stdout") parser.add_argument("-p", "--prompt", type=str, @@ -24,7 +26,6 @@ def main(): help="Positional arguments treated as the prompt if -p is omitted") args = parser.parse_args() - prompt_text = args.prompt if args.prompt else " ".join(args.positional_prompt) if args.positional_prompt else None if not os.environ.get("GEMINI_API_KEY"): @@ -33,8 +34,7 @@ def main(): client = genai.Client() - # Updated structure to track caches and their associated file states by model - context_data = {"file_ids": [], "caches": {}} + context_data = {"file_ids": [], "caches": {}, "history": []} if args.context and os.path.exists(args.context): try: @@ -42,13 +42,22 @@ def main(): loaded_data = json.load(f) context_data["file_ids"] = loaded_data.get("file_ids", []) context_data["caches"] = loaded_data.get("caches", {}) - - # Handle migration from old format - if "cache_id" in loaded_data and loaded_data["cache_id"]: - print("Warning: Old context format detected. Please use -d to destroy and start fresh, or unexpected cache behavior may occur.", file=sys.stderr) + context_data["history"] = loaded_data.get("history", []) except json.JSONDecodeError: print(f"Warning: Could not parse {args.context}. Starting fresh.", file=sys.stderr) + # --------------------------------------------------------- + # CLEAR HISTORY FLAG + # --------------------------------------------------------- + if args.clear_history: + context_data["history"] = [] + if args.context: + with open(args.context, "w") as f: + json.dump(context_data, f, indent=4) + print("Conversation history cleared.") + if not prompt_text and not args.files and not args.destroy: + return + # --------------------------------------------------------- # DESTROY FLAG LOGIC # --------------------------------------------------------- @@ -172,7 +181,7 @@ def main(): print(f"Warning: Failed to update cache TTL. {e}") # --------------------------------------------------------- - # GENERATION LOGIC + # GENERATION LOGIC (WITH HISTORY) # --------------------------------------------------------- if prompt_text: config_kwargs = { @@ -180,39 +189,72 @@ def main(): "temperature": 0.0 } - generation_contents = [] - if active_cache_id and not cache_too_small: config_kwargs["cached_content"] = active_cache_id else: - generation_contents.extend(file_objects) config_kwargs["system_instruction"] = system_instruction + + # Build the chat payload using strict dictionary representations + api_contents = [] + files_added_to_payload = False + + for msg in context_data.get("history", []): + parts = [] + # If we aren't using a cache, attach un-cached files to the very first user message + if msg["role"] == "user" and not files_added_to_payload and not active_cache_id and file_objects: + for f in file_objects: + parts.append({"file_data": {"file_uri": f.uri, "mime_type": f.mime_type}}) + files_added_to_payload = True + + parts.append({"text": msg["text"]}) + api_contents.append({"role": msg["role"], "parts": parts}) + + # Add the current prompt + current_parts = [] + if not files_added_to_payload and not active_cache_id and file_objects: + for f in file_objects: + current_parts.append({"file_data": {"file_uri": f.uri, "mime_type": f.mime_type}}) + files_added_to_payload = True - generation_contents.append(prompt_text) + current_parts.append({"text": prompt_text}) + api_contents.append({"role": "user", "parts": current_parts}) + config = types.GenerateContentConfig(**config_kwargs) print("Generating response (this may take a moment for large outputs)...") response_stream = client.models.generate_content_stream( model=args.model, - contents=generation_contents, + contents=api_contents, config=config ) + full_response_text = "" + if args.output: with open(args.output, "w") as f: for chunk in response_stream: if chunk.text: f.write(chunk.text) f.flush() + full_response_text += chunk.text print(f"\nDone! Raw output saved directly to {args.output}") else: print("-" * 40) for chunk in response_stream: if chunk.text: print(chunk.text, end="", flush=True) + full_response_text += chunk.text print("\n" + "-" * 40) + # Append this turn to the local history and save + context_data["history"].append({"role": "user", "text": prompt_text}) + context_data["history"].append({"role": "model", "text": full_response_text}) + + if args.context: + with open(args.context, "w") as f: + json.dump(context_data, f, indent=4) + finally: # --------------------------------------------------------- # TRANSIENT MODE CLEANUP