Added history to the conversation

This commit is contained in:
2026-06-03 15:55:23 -07:00
parent 5d4184f214
commit 0b1037e4e1
+57 -15
View File
@@ -9,13 +9,15 @@ from google.genai import types
def main(): def main():
parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching") parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching")
parser.add_argument("-c", "--context", type=str, default=None, parser.add_argument("-c", "--context", type=str, default=None,
help="Path to context file. If omitted, files/caches are deleted after execution.") help="Path to context file. If omitted, transient mode is used.")
parser.add_argument("-f", "--files", nargs="+", default=[], parser.add_argument("-f", "--files", nargs="+", default=[],
help="Files to upload to the Gemini API") help="Files to upload to the Gemini API")
parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite", parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite",
help="The model to use (default: gemini-3.1-flash-lite)") help="The model to use (default: gemini-3.1-flash-lite)")
parser.add_argument("-d", "--destroy", action="store_true", parser.add_argument("-d", "--destroy", action="store_true",
help="Destroy the cloud files and cache, and delete the local context file") help="Destroy cloud files/cache, and delete local context")
parser.add_argument("-x", "--clear-history", action="store_true",
help="Clear the conversation history without destroying files/caches")
parser.add_argument("-o", "--output", type=str, parser.add_argument("-o", "--output", type=str,
help="Direct the raw output to a specific file instead of stdout") help="Direct the raw output to a specific file instead of stdout")
parser.add_argument("-p", "--prompt", type=str, parser.add_argument("-p", "--prompt", type=str,
@@ -24,7 +26,6 @@ def main():
help="Positional arguments treated as the prompt if -p is omitted") help="Positional arguments treated as the prompt if -p is omitted")
args = parser.parse_args() args = parser.parse_args()
prompt_text = args.prompt if args.prompt else " ".join(args.positional_prompt) if args.positional_prompt else None prompt_text = args.prompt if args.prompt else " ".join(args.positional_prompt) if args.positional_prompt else None
if not os.environ.get("GEMINI_API_KEY"): if not os.environ.get("GEMINI_API_KEY"):
@@ -33,8 +34,7 @@ def main():
client = genai.Client() client = genai.Client()
# Updated structure to track caches and their associated file states by model context_data = {"file_ids": [], "caches": {}, "history": []}
context_data = {"file_ids": [], "caches": {}}
if args.context and os.path.exists(args.context): if args.context and os.path.exists(args.context):
try: try:
@@ -42,13 +42,22 @@ def main():
loaded_data = json.load(f) loaded_data = json.load(f)
context_data["file_ids"] = loaded_data.get("file_ids", []) context_data["file_ids"] = loaded_data.get("file_ids", [])
context_data["caches"] = loaded_data.get("caches", {}) context_data["caches"] = loaded_data.get("caches", {})
context_data["history"] = loaded_data.get("history", [])
# Handle migration from old format
if "cache_id" in loaded_data and loaded_data["cache_id"]:
print("Warning: Old context format detected. Please use -d to destroy and start fresh, or unexpected cache behavior may occur.", file=sys.stderr)
except json.JSONDecodeError: except json.JSONDecodeError:
print(f"Warning: Could not parse {args.context}. Starting fresh.", file=sys.stderr) print(f"Warning: Could not parse {args.context}. Starting fresh.", file=sys.stderr)
# ---------------------------------------------------------
# CLEAR HISTORY FLAG
# ---------------------------------------------------------
if args.clear_history:
context_data["history"] = []
if args.context:
with open(args.context, "w") as f:
json.dump(context_data, f, indent=4)
print("Conversation history cleared.")
if not prompt_text and not args.files and not args.destroy:
return
# --------------------------------------------------------- # ---------------------------------------------------------
# DESTROY FLAG LOGIC # DESTROY FLAG LOGIC
# --------------------------------------------------------- # ---------------------------------------------------------
@@ -172,7 +181,7 @@ def main():
print(f"Warning: Failed to update cache TTL. {e}") print(f"Warning: Failed to update cache TTL. {e}")
# --------------------------------------------------------- # ---------------------------------------------------------
# GENERATION LOGIC # GENERATION LOGIC (WITH HISTORY)
# --------------------------------------------------------- # ---------------------------------------------------------
if prompt_text: if prompt_text:
config_kwargs = { config_kwargs = {
@@ -180,39 +189,72 @@ def main():
"temperature": 0.0 "temperature": 0.0
} }
generation_contents = []
if active_cache_id and not cache_too_small: if active_cache_id and not cache_too_small:
config_kwargs["cached_content"] = active_cache_id config_kwargs["cached_content"] = active_cache_id
else: else:
generation_contents.extend(file_objects)
config_kwargs["system_instruction"] = system_instruction config_kwargs["system_instruction"] = system_instruction
# Build the chat payload using strict dictionary representations
api_contents = []
files_added_to_payload = False
for msg in context_data.get("history", []):
parts = []
# If we aren't using a cache, attach un-cached files to the very first user message
if msg["role"] == "user" and not files_added_to_payload and not active_cache_id and file_objects:
for f in file_objects:
parts.append({"file_data": {"file_uri": f.uri, "mime_type": f.mime_type}})
files_added_to_payload = True
parts.append({"text": msg["text"]})
api_contents.append({"role": msg["role"], "parts": parts})
# Add the current prompt
current_parts = []
if not files_added_to_payload and not active_cache_id and file_objects:
for f in file_objects:
current_parts.append({"file_data": {"file_uri": f.uri, "mime_type": f.mime_type}})
files_added_to_payload = True
generation_contents.append(prompt_text) current_parts.append({"text": prompt_text})
api_contents.append({"role": "user", "parts": current_parts})
config = types.GenerateContentConfig(**config_kwargs) config = types.GenerateContentConfig(**config_kwargs)
print("Generating response (this may take a moment for large outputs)...") print("Generating response (this may take a moment for large outputs)...")
response_stream = client.models.generate_content_stream( response_stream = client.models.generate_content_stream(
model=args.model, model=args.model,
contents=generation_contents, contents=api_contents,
config=config config=config
) )
full_response_text = ""
if args.output: if args.output:
with open(args.output, "w") as f: with open(args.output, "w") as f:
for chunk in response_stream: for chunk in response_stream:
if chunk.text: if chunk.text:
f.write(chunk.text) f.write(chunk.text)
f.flush() f.flush()
full_response_text += chunk.text
print(f"\nDone! Raw output saved directly to {args.output}") print(f"\nDone! Raw output saved directly to {args.output}")
else: else:
print("-" * 40) print("-" * 40)
for chunk in response_stream: for chunk in response_stream:
if chunk.text: if chunk.text:
print(chunk.text, end="", flush=True) print(chunk.text, end="", flush=True)
full_response_text += chunk.text
print("\n" + "-" * 40) print("\n" + "-" * 40)
# Append this turn to the local history and save
context_data["history"].append({"role": "user", "text": prompt_text})
context_data["history"].append({"role": "model", "text": full_response_text})
if args.context:
with open(args.context, "w") as f:
json.dump(context_data, f, indent=4)
finally: finally:
# --------------------------------------------------------- # ---------------------------------------------------------
# TRANSIENT MODE CLEANUP # TRANSIENT MODE CLEANUP