diff --git a/gemini/gemini.py b/gemini/gemini.py index eb3722c..a57d807 100644 --- a/gemini/gemini.py +++ b/gemini/gemini.py @@ -3,9 +3,8 @@ import argparse import json import os import sys -import datetime -import google.generativeai as genai -from google.generativeai import caching +from google import genai +from google.genai import types def main(): parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching") @@ -13,8 +12,8 @@ def main(): help="Path to the context file (defaults to .gemini_context.json)") parser.add_argument("-f", "--files", nargs="+", default=[], help="Files to upload to the Gemini API") - parser.add_argument("-m", "--model", type=str, default="models/gemini-1.5-pro-001", - help="The model to use (default: models/gemini-1.5-pro-001)") + parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite", + help="The model to use (default: gemini-3.1-flash-lite)") parser.add_argument("-d", "--destroy", action="store_true", help="Destroy the cloud files and cache, and delete the local context file") parser.add_argument("-o", "--output", type=str, @@ -29,12 +28,12 @@ def main(): # Determine the prompt text prompt_text = args.prompt if args.prompt else " ".join(args.positional_prompt) if args.positional_prompt else None - # Authenticate - api_key = os.environ.get("GEMINI_API_KEY") - if not api_key: + # Authenticate: The new Client() automatically looks for the GEMINI_API_KEY environment variable. + if not os.environ.get("GEMINI_API_KEY"): print("Error: GEMINI_API_KEY environment variable is not set.", file=sys.stderr) sys.exit(1) - genai.configure(api_key=api_key) + + client = genai.Client() context_data = {"file_ids": [], "cache_id": None} @@ -53,14 +52,14 @@ def main(): print("Destroying server resources and local context...") if context_data.get("cache_id"): try: - caching.CachedContent.get(context_data["cache_id"]).delete() + client.caches.delete(name=context_data["cache_id"]) print(f"Deleted cache: {context_data['cache_id']}") except Exception as e: print(f"Warning: Failed to delete cache. {e}", file=sys.stderr) for file_id in context_data.get("file_ids", []): try: - genai.delete_file(file_id) + client.files.delete(name=file_id) print(f"Deleted file: {file_id}") except Exception as e: print(f"Warning: Failed to delete file '{file_id}'. {e}", file=sys.stderr) @@ -82,7 +81,7 @@ def main(): continue print(f"Uploading '{file_path}'...") - uploaded_file = genai.upload_file(file_path) + uploaded_file = client.files.upload(file=file_path) print(f"Success: '{file_path}' uploaded as '{uploaded_file.name}'") if uploaded_file.name not in context_data["file_ids"]: @@ -94,68 +93,87 @@ def main(): # --------------------------------------------------------- # CACHE CREATION AND TTL EXTENSION LOGIC # --------------------------------------------------------- - cache = None system_instruction = ( - "You are a strict data extraction tool. Output exactly what is requested " - "(e.g., CSV). Never use markdown formatting blocks (like ```csv). " - "Never include conversational text, greetings, or explanations. Output raw data only." +# "You are a strict data extraction tool. Output exactly what is requested " +# "(e.g., CSV). Never use markdown formatting blocks (like ```csv). " +# "Never include conversational text, greetings, or explanations. Output raw data only." + "You are a hybrid data extraction tool. If a specific format or file format output is requested " + "(e.g., CSV), then output exactly what was requested and never use markdown formatting blocks (like ```csv). " + "If a specific format was requested, never include conversational text, greetings, or explanations; " + "and output raw data only. If not specific data format is suggested, you can answer with conversational text." ) if context_data.get("file_ids"): if not context_data.get("cache_id"): print("Creating Context Cache on Google's servers...") - file_objects = [genai.get_file(f_id) for f_id in context_data["file_ids"]] - cache = caching.CachedContent.create( + # Retrieve the file objects + file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]] + + # Create the server-side cache (Set to expire in 3600 seconds / 60 minutes) + cache = client.caches.create( model=args.model, - system_instruction=system_instruction, - contents=file_objects, - ttl=datetime.timedelta(minutes=60) + config=types.CreateCachedContentConfig( + contents=file_objects, + system_instruction=system_instruction, + ttl="3600s" + ) ) context_data["cache_id"] = cache.name with open(args.context, "w") as f: json.dump(context_data, f, indent=4) print(f"Context Cache created: {cache.name}") + else: print(f"Loading existing cache: {context_data['cache_id']}") - cache = caching.CachedContent.get(context_data["cache_id"]) print("Extending cache TTL by 60 minutes...") - cache.update(ttl=datetime.timedelta(minutes=60)) + client.caches.update( + name=context_data["cache_id"], + config=types.UpdateCachedContentConfig( + ttl="3600s" + ) + ) # --------------------------------------------------------- # GENERATION LOGIC # --------------------------------------------------------- if prompt_text: - generation_config = { + config_kwargs = { "max_output_tokens": 65536, "temperature": 0.0 } - if cache: - model = genai.GenerativeModel.from_cached_content( - cached_content=cache, - generation_config=generation_config - ) + # The new SDK passes the cache ID directly into the generation config. + # If caching is used, the system_instruction must be in the cache, not here. + if context_data.get("cache_id"): + config_kwargs["cached_content"] = context_data["cache_id"] else: - # Fallback if no files were uploaded - model = genai.GenerativeModel( - model_name=args.model, - system_instruction=system_instruction, - generation_config=generation_config - ) + config_kwargs["system_instruction"] = system_instruction + + config = types.GenerateContentConfig(**config_kwargs) print("Generating response (this may take a moment for large outputs)...") - response = model.generate_content(prompt_text) + # We use stream generation so it writes immediately, avoiding memory bottlenecks + response_stream = client.models.generate_content_stream( + model=args.model, + contents=prompt_text, + config=config + ) if args.output: with open(args.output, "w") as f: - f.write(response.text) - print(f"Done! Raw output saved directly to {args.output}") + for chunk in response_stream: + if chunk.text: + f.write(chunk.text) + f.flush() # Stream direct to the disk in real-time + print(f"\nDone! Raw output saved directly to {args.output}") else: print("-" * 40) - print(response.text) - print("-" * 40) + for chunk in response_stream: + if chunk.text: + print(chunk.text, end="", flush=True) + print("\n" + "-" * 40) if __name__ == "__main__": - main() \ No newline at end of file + main()