diff --git a/gemini/gemini.py b/gemini/gemini.py index 98018cf..6ba3d8f 100644 --- a/gemini/gemini.py +++ b/gemini/gemini.py @@ -8,8 +8,8 @@ from google.genai import types def main(): parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching") - parser.add_argument("-c", "--context", type=str, default=".gemini_context.json", - help="Path to the context file (defaults to .gemini_context.json)") + parser.add_argument("-c", "--context", type=str, default=None, + help="Path to context file. If omitted, files/caches are deleted after execution.") parser.add_argument("-f", "--files", nargs="+", default=[], help="Files to upload to the Gemini API") parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite", @@ -25,20 +25,17 @@ def main(): args = parser.parse_args() - # Determine the prompt text prompt_text = args.prompt if args.prompt else " ".join(args.positional_prompt) if args.positional_prompt else None - # Authenticate: The new Client() automatically looks for the GEMINI_API_KEY environment variable. if not os.environ.get("GEMINI_API_KEY"): print("Error: GEMINI_API_KEY environment variable is not set.", file=sys.stderr) sys.exit(1) client = genai.Client() - context_data = {"file_ids": [], "cache_id": None} - # Load existing context if it exists - if os.path.exists(args.context): + # Load existing context if a context file is specified and exists + if args.context and os.path.exists(args.context): try: with open(args.context, "r") as f: context_data = json.load(f) @@ -64,36 +61,39 @@ def main(): except Exception as e: print(f"Warning: Failed to delete file '{file_id}'. {e}", file=sys.stderr) - if os.path.exists(args.context): + if args.context and os.path.exists(args.context): os.remove(args.context) print(f"Deleted local context file: {args.context}") print("Cleanup complete.") - return # Exit after destroying + return - # --------------------------------------------------------- - # UPLOAD LOGIC - # --------------------------------------------------------- - if args.files: - for file_path in args.files: - if not os.path.exists(file_path): - print(f"Warning: File '{file_path}' not found. Skipping.", file=sys.stderr) - continue - - print(f"Uploading '{file_path}'...") - uploaded_file = client.files.upload(file=file_path) - print(f"Success: '{file_path}' uploaded as '{uploaded_file.name}'") - - if uploaded_file.name not in context_data["file_ids"]: - context_data["file_ids"].append(uploaded_file.name) + # Wrap the main execution in a try/finally to guarantee cleanup on transient runs + try: + # --------------------------------------------------------- + # UPLOAD LOGIC + # --------------------------------------------------------- + if args.files: + for file_path in args.files: + if not os.path.exists(file_path): + print(f"Warning: File '{file_path}' not found. Skipping.", file=sys.stderr) + continue + + print(f"Uploading '{file_path}'...") + uploaded_file = client.files.upload(file=file_path) + print(f"Success: '{file_path}' uploaded as '{uploaded_file.name}'") + + if uploaded_file.name not in context_data["file_ids"]: + context_data["file_ids"].append(uploaded_file.name) - with open(args.context, "w") as f: - json.dump(context_data, f, indent=4) + if args.context: + with open(args.context, "w") as f: + json.dump(context_data, f, indent=4) - # --------------------------------------------------------- - # CACHE CREATION AND TTL EXTENSION LOGIC - # --------------------------------------------------------- - system_instruction = ( + # --------------------------------------------------------- + # CACHE CREATION LOGIC + # --------------------------------------------------------- + system_instruction = ( # "You are a strict data extraction tool. Output exactly what is requested " # "(e.g., CSV). Never use markdown formatting blocks (like ```csv). " # "Never include conversational text, greetings, or explanations. Output raw data only." @@ -101,98 +101,110 @@ def main(): "(e.g., CSV), then output exactly what was requested and never use markdown formatting blocks (like ```csv). " "If a specific format was requested, never include conversational text, greetings, or explanations; " "and output raw data only. If not specific data format is suggested, you can answer with conversational text." - ) - - cache_too_small = False - file_objects = [] - - if context_data.get("file_ids"): - # We always need the file objects just in case caching fails - file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]] - - if not context_data.get("cache_id"): - print("Attempting to create Context Cache on Google's servers...") - - try: - # Attempt to create the cache - cache = client.caches.create( - model=args.model, - config=types.CreateCachedContentConfig( - contents=file_objects, - system_instruction=system_instruction, - ttl="3600s" - ) - ) - context_data["cache_id"] = cache.name - with open(args.context, "w") as f: - json.dump(context_data, f, indent=4) - print(f"Context Cache created: {cache.name}") - - except Exception as e: - # Catch the specific size error and fall back - if "too small" in str(e).lower() or "1024" in str(e): - print("Notice: Files are too small for server-side caching (under 1024 tokens). Falling back to standard processing.") - cache_too_small = True - else: - raise e # Reraise if it's a different error (like authentication) - - elif not cache_too_small: - print(f"Loading existing cache: {context_data['cache_id']}") - print("Extending cache TTL by 60 minutes...") - try: - client.caches.update( - name=context_data["cache_id"], - config=types.UpdateCachedContentConfig( - ttl="3600s" - ) - ) - except Exception as e: - print(f"Warning: Failed to update cache TTL. It may have expired. {e}") - - # --------------------------------------------------------- - # GENERATION LOGIC - # --------------------------------------------------------- - if prompt_text: - config_kwargs = { - "max_output_tokens": 65536, - "temperature": 0.0 - } - - # Setup contents array - generation_contents = [] - - if context_data.get("cache_id") and not cache_too_small: - # If we successfully cached, we just pass the cache ID in the config - config_kwargs["cached_content"] = context_data["cache_id"] - else: - # If we didn't cache (or it was too small), pass the files and system instruction directly - generation_contents.extend(file_objects) - config_kwargs["system_instruction"] = system_instruction - - generation_contents.append(prompt_text) - config = types.GenerateContentConfig(**config_kwargs) - - print("Generating response (this may take a moment for large outputs)...") - - response_stream = client.models.generate_content_stream( - model=args.model, - contents=generation_contents, - config=config ) - - if args.output: - with open(args.output, "w") as f: + + cache_too_small = False + file_objects = [] + + if context_data.get("file_ids"): + file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]] + + if not context_data.get("cache_id"): + print("Attempting to create Context Cache on Google's servers...") + try: + cache = client.caches.create( + model=args.model, + config=types.CreateCachedContentConfig( + contents=file_objects, + system_instruction=system_instruction, + ttl="3600s" + ) + ) + context_data["cache_id"] = cache.name + if args.context: + with open(args.context, "w") as f: + json.dump(context_data, f, indent=4) + print(f"Context Cache created: {cache.name}") + + except Exception as e: + if "too small" in str(e).lower() or "1024" in str(e): + print("Notice: Files are too small for server-side caching (under 1024 tokens). Falling back to standard processing.") + cache_too_small = True + else: + raise e + + elif not cache_too_small: + print(f"Loading existing cache: {context_data['cache_id']}") + print("Extending cache TTL by 60 minutes...") + try: + client.caches.update( + name=context_data["cache_id"], + config=types.UpdateCachedContentConfig(ttl="3600s") + ) + except Exception as e: + print(f"Warning: Failed to update cache TTL. {e}") + + # --------------------------------------------------------- + # GENERATION LOGIC + # --------------------------------------------------------- + if prompt_text: + config_kwargs = { + "max_output_tokens": 65536, + "temperature": 0.0 + } + + generation_contents = [] + + if context_data.get("cache_id") and not cache_too_small: + config_kwargs["cached_content"] = context_data["cache_id"] + else: + generation_contents.extend(file_objects) + config_kwargs["system_instruction"] = system_instruction + + generation_contents.append(prompt_text) + config = types.GenerateContentConfig(**config_kwargs) + + print("Generating response (this may take a moment for large outputs)...") + + response_stream = client.models.generate_content_stream( + model=args.model, + contents=generation_contents, + config=config + ) + + if args.output: + with open(args.output, "w") as f: + for chunk in response_stream: + if chunk.text: + f.write(chunk.text) + f.flush() + print(f"\nDone! Raw output saved directly to {args.output}") + else: + print("-" * 40) for chunk in response_stream: if chunk.text: - f.write(chunk.text) - f.flush() - print(f"\nDone! Raw output saved directly to {args.output}") - else: - print("-" * 40) - for chunk in response_stream: - if chunk.text: - print(chunk.text, end="", flush=True) - print("\n" + "-" * 40) + print(chunk.text, end="", flush=True) + print("\n" + "-" * 40) + + finally: + # --------------------------------------------------------- + # TRANSIENT MODE CLEANUP + # --------------------------------------------------------- + if not args.context and not args.destroy: + print("\n[Transient Mode] Cleaning up resources...") + if context_data.get("cache_id"): + try: + client.caches.delete(name=context_data["cache_id"]) + print(f"Deleted cache: {context_data['cache_id']}") + except Exception as e: + print(f"Warning: Failed to delete cache. {e}", file=sys.stderr) + + for file_id in context_data.get("file_ids", []): + try: + client.files.delete(name=file_id) + print(f"Deleted file: {file_id}") + except Exception as e: + print(f"Warning: Failed to delete file '{file_id}'. {e}", file=sys.stderr) if __name__ == "__main__": main()