Working on the gemini CLI

This commit is contained in:
2026-06-03 15:10:24 -07:00
parent d2209fd209
commit 8d9fe389c6
+34 -22
View File
@@ -8,8 +8,8 @@ from google.genai import types
def main(): def main():
parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching") parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching")
parser.add_argument("-c", "--context", type=str, default=".gemini_context.json", parser.add_argument("-c", "--context", type=str, default=None,
help="Path to the context file (defaults to .gemini_context.json)") help="Path to context file. If omitted, files/caches are deleted after execution.")
parser.add_argument("-f", "--files", nargs="+", default=[], parser.add_argument("-f", "--files", nargs="+", default=[],
help="Files to upload to the Gemini API") help="Files to upload to the Gemini API")
parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite", parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite",
@@ -25,20 +25,17 @@ def main():
args = parser.parse_args() args = parser.parse_args()
# Determine the prompt text
prompt_text = args.prompt if args.prompt else " ".join(args.positional_prompt) if args.positional_prompt else None prompt_text = args.prompt if args.prompt else " ".join(args.positional_prompt) if args.positional_prompt else None
# Authenticate: The new Client() automatically looks for the GEMINI_API_KEY environment variable.
if not os.environ.get("GEMINI_API_KEY"): if not os.environ.get("GEMINI_API_KEY"):
print("Error: GEMINI_API_KEY environment variable is not set.", file=sys.stderr) print("Error: GEMINI_API_KEY environment variable is not set.", file=sys.stderr)
sys.exit(1) sys.exit(1)
client = genai.Client() client = genai.Client()
context_data = {"file_ids": [], "cache_id": None} context_data = {"file_ids": [], "cache_id": None}
# Load existing context if it exists # Load existing context if a context file is specified and exists
if os.path.exists(args.context): if args.context and os.path.exists(args.context):
try: try:
with open(args.context, "r") as f: with open(args.context, "r") as f:
context_data = json.load(f) context_data = json.load(f)
@@ -64,13 +61,15 @@ def main():
except Exception as e: except Exception as e:
print(f"Warning: Failed to delete file '{file_id}'. {e}", file=sys.stderr) print(f"Warning: Failed to delete file '{file_id}'. {e}", file=sys.stderr)
if os.path.exists(args.context): if args.context and os.path.exists(args.context):
os.remove(args.context) os.remove(args.context)
print(f"Deleted local context file: {args.context}") print(f"Deleted local context file: {args.context}")
print("Cleanup complete.") print("Cleanup complete.")
return # Exit after destroying return
# Wrap the main execution in a try/finally to guarantee cleanup on transient runs
try:
# --------------------------------------------------------- # ---------------------------------------------------------
# UPLOAD LOGIC # UPLOAD LOGIC
# --------------------------------------------------------- # ---------------------------------------------------------
@@ -87,11 +86,12 @@ def main():
if uploaded_file.name not in context_data["file_ids"]: if uploaded_file.name not in context_data["file_ids"]:
context_data["file_ids"].append(uploaded_file.name) context_data["file_ids"].append(uploaded_file.name)
if args.context:
with open(args.context, "w") as f: with open(args.context, "w") as f:
json.dump(context_data, f, indent=4) json.dump(context_data, f, indent=4)
# --------------------------------------------------------- # ---------------------------------------------------------
# CACHE CREATION AND TTL EXTENSION LOGIC # CACHE CREATION LOGIC
# --------------------------------------------------------- # ---------------------------------------------------------
system_instruction = ( system_instruction = (
# "You are a strict data extraction tool. Output exactly what is requested " # "You are a strict data extraction tool. Output exactly what is requested "
@@ -107,14 +107,11 @@ def main():
file_objects = [] file_objects = []
if context_data.get("file_ids"): if context_data.get("file_ids"):
# We always need the file objects just in case caching fails
file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]] file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]]
if not context_data.get("cache_id"): if not context_data.get("cache_id"):
print("Attempting to create Context Cache on Google's servers...") print("Attempting to create Context Cache on Google's servers...")
try: try:
# Attempt to create the cache
cache = client.caches.create( cache = client.caches.create(
model=args.model, model=args.model,
config=types.CreateCachedContentConfig( config=types.CreateCachedContentConfig(
@@ -124,17 +121,17 @@ def main():
) )
) )
context_data["cache_id"] = cache.name context_data["cache_id"] = cache.name
if args.context:
with open(args.context, "w") as f: with open(args.context, "w") as f:
json.dump(context_data, f, indent=4) json.dump(context_data, f, indent=4)
print(f"Context Cache created: {cache.name}") print(f"Context Cache created: {cache.name}")
except Exception as e: except Exception as e:
# Catch the specific size error and fall back
if "too small" in str(e).lower() or "1024" in str(e): if "too small" in str(e).lower() or "1024" in str(e):
print("Notice: Files are too small for server-side caching (under 1024 tokens). Falling back to standard processing.") print("Notice: Files are too small for server-side caching (under 1024 tokens). Falling back to standard processing.")
cache_too_small = True cache_too_small = True
else: else:
raise e # Reraise if it's a different error (like authentication) raise e
elif not cache_too_small: elif not cache_too_small:
print(f"Loading existing cache: {context_data['cache_id']}") print(f"Loading existing cache: {context_data['cache_id']}")
@@ -142,12 +139,10 @@ def main():
try: try:
client.caches.update( client.caches.update(
name=context_data["cache_id"], name=context_data["cache_id"],
config=types.UpdateCachedContentConfig( config=types.UpdateCachedContentConfig(ttl="3600s")
ttl="3600s"
)
) )
except Exception as e: except Exception as e:
print(f"Warning: Failed to update cache TTL. It may have expired. {e}") print(f"Warning: Failed to update cache TTL. {e}")
# --------------------------------------------------------- # ---------------------------------------------------------
# GENERATION LOGIC # GENERATION LOGIC
@@ -158,14 +153,11 @@ def main():
"temperature": 0.0 "temperature": 0.0
} }
# Setup contents array
generation_contents = [] generation_contents = []
if context_data.get("cache_id") and not cache_too_small: if context_data.get("cache_id") and not cache_too_small:
# If we successfully cached, we just pass the cache ID in the config
config_kwargs["cached_content"] = context_data["cache_id"] config_kwargs["cached_content"] = context_data["cache_id"]
else: else:
# If we didn't cache (or it was too small), pass the files and system instruction directly
generation_contents.extend(file_objects) generation_contents.extend(file_objects)
config_kwargs["system_instruction"] = system_instruction config_kwargs["system_instruction"] = system_instruction
@@ -194,5 +186,25 @@ def main():
print(chunk.text, end="", flush=True) print(chunk.text, end="", flush=True)
print("\n" + "-" * 40) print("\n" + "-" * 40)
finally:
# ---------------------------------------------------------
# TRANSIENT MODE CLEANUP
# ---------------------------------------------------------
if not args.context and not args.destroy:
print("\n[Transient Mode] Cleaning up resources...")
if context_data.get("cache_id"):
try:
client.caches.delete(name=context_data["cache_id"])
print(f"Deleted cache: {context_data['cache_id']}")
except Exception as e:
print(f"Warning: Failed to delete cache. {e}", file=sys.stderr)
for file_id in context_data.get("file_ids", []):
try:
client.files.delete(name=file_id)
print(f"Deleted file: {file_id}")
except Exception as e:
print(f"Warning: Failed to delete file '{file_id}'. {e}", file=sys.stderr)
if __name__ == "__main__": if __name__ == "__main__":
main() main()