Added history to the conversation
This commit is contained in:
+57
-15
@@ -9,13 +9,15 @@ from google.genai import types
|
|||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching")
|
parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching")
|
||||||
parser.add_argument("-c", "--context", type=str, default=None,
|
parser.add_argument("-c", "--context", type=str, default=None,
|
||||||
help="Path to context file. If omitted, files/caches are deleted after execution.")
|
help="Path to context file. If omitted, transient mode is used.")
|
||||||
parser.add_argument("-f", "--files", nargs="+", default=[],
|
parser.add_argument("-f", "--files", nargs="+", default=[],
|
||||||
help="Files to upload to the Gemini API")
|
help="Files to upload to the Gemini API")
|
||||||
parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite",
|
parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite",
|
||||||
help="The model to use (default: gemini-3.1-flash-lite)")
|
help="The model to use (default: gemini-3.1-flash-lite)")
|
||||||
parser.add_argument("-d", "--destroy", action="store_true",
|
parser.add_argument("-d", "--destroy", action="store_true",
|
||||||
help="Destroy the cloud files and cache, and delete the local context file")
|
help="Destroy cloud files/cache, and delete local context")
|
||||||
|
parser.add_argument("-x", "--clear-history", action="store_true",
|
||||||
|
help="Clear the conversation history without destroying files/caches")
|
||||||
parser.add_argument("-o", "--output", type=str,
|
parser.add_argument("-o", "--output", type=str,
|
||||||
help="Direct the raw output to a specific file instead of stdout")
|
help="Direct the raw output to a specific file instead of stdout")
|
||||||
parser.add_argument("-p", "--prompt", type=str,
|
parser.add_argument("-p", "--prompt", type=str,
|
||||||
@@ -24,7 +26,6 @@ def main():
|
|||||||
help="Positional arguments treated as the prompt if -p is omitted")
|
help="Positional arguments treated as the prompt if -p is omitted")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
prompt_text = args.prompt if args.prompt else " ".join(args.positional_prompt) if args.positional_prompt else None
|
prompt_text = args.prompt if args.prompt else " ".join(args.positional_prompt) if args.positional_prompt else None
|
||||||
|
|
||||||
if not os.environ.get("GEMINI_API_KEY"):
|
if not os.environ.get("GEMINI_API_KEY"):
|
||||||
@@ -33,8 +34,7 @@ def main():
|
|||||||
|
|
||||||
client = genai.Client()
|
client = genai.Client()
|
||||||
|
|
||||||
# Updated structure to track caches and their associated file states by model
|
context_data = {"file_ids": [], "caches": {}, "history": []}
|
||||||
context_data = {"file_ids": [], "caches": {}}
|
|
||||||
|
|
||||||
if args.context and os.path.exists(args.context):
|
if args.context and os.path.exists(args.context):
|
||||||
try:
|
try:
|
||||||
@@ -42,13 +42,22 @@ def main():
|
|||||||
loaded_data = json.load(f)
|
loaded_data = json.load(f)
|
||||||
context_data["file_ids"] = loaded_data.get("file_ids", [])
|
context_data["file_ids"] = loaded_data.get("file_ids", [])
|
||||||
context_data["caches"] = loaded_data.get("caches", {})
|
context_data["caches"] = loaded_data.get("caches", {})
|
||||||
|
context_data["history"] = loaded_data.get("history", [])
|
||||||
# Handle migration from old format
|
|
||||||
if "cache_id" in loaded_data and loaded_data["cache_id"]:
|
|
||||||
print("Warning: Old context format detected. Please use -d to destroy and start fresh, or unexpected cache behavior may occur.", file=sys.stderr)
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
print(f"Warning: Could not parse {args.context}. Starting fresh.", file=sys.stderr)
|
print(f"Warning: Could not parse {args.context}. Starting fresh.", file=sys.stderr)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------
|
||||||
|
# CLEAR HISTORY FLAG
|
||||||
|
# ---------------------------------------------------------
|
||||||
|
if args.clear_history:
|
||||||
|
context_data["history"] = []
|
||||||
|
if args.context:
|
||||||
|
with open(args.context, "w") as f:
|
||||||
|
json.dump(context_data, f, indent=4)
|
||||||
|
print("Conversation history cleared.")
|
||||||
|
if not prompt_text and not args.files and not args.destroy:
|
||||||
|
return
|
||||||
|
|
||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
# DESTROY FLAG LOGIC
|
# DESTROY FLAG LOGIC
|
||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
@@ -172,7 +181,7 @@ def main():
|
|||||||
print(f"Warning: Failed to update cache TTL. {e}")
|
print(f"Warning: Failed to update cache TTL. {e}")
|
||||||
|
|
||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
# GENERATION LOGIC
|
# GENERATION LOGIC (WITH HISTORY)
|
||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
if prompt_text:
|
if prompt_text:
|
||||||
config_kwargs = {
|
config_kwargs = {
|
||||||
@@ -180,39 +189,72 @@ def main():
|
|||||||
"temperature": 0.0
|
"temperature": 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_contents = []
|
|
||||||
|
|
||||||
if active_cache_id and not cache_too_small:
|
if active_cache_id and not cache_too_small:
|
||||||
config_kwargs["cached_content"] = active_cache_id
|
config_kwargs["cached_content"] = active_cache_id
|
||||||
else:
|
else:
|
||||||
generation_contents.extend(file_objects)
|
|
||||||
config_kwargs["system_instruction"] = system_instruction
|
config_kwargs["system_instruction"] = system_instruction
|
||||||
|
|
||||||
|
# Build the chat payload using strict dictionary representations
|
||||||
|
api_contents = []
|
||||||
|
files_added_to_payload = False
|
||||||
|
|
||||||
|
for msg in context_data.get("history", []):
|
||||||
|
parts = []
|
||||||
|
# If we aren't using a cache, attach un-cached files to the very first user message
|
||||||
|
if msg["role"] == "user" and not files_added_to_payload and not active_cache_id and file_objects:
|
||||||
|
for f in file_objects:
|
||||||
|
parts.append({"file_data": {"file_uri": f.uri, "mime_type": f.mime_type}})
|
||||||
|
files_added_to_payload = True
|
||||||
|
|
||||||
|
parts.append({"text": msg["text"]})
|
||||||
|
api_contents.append({"role": msg["role"], "parts": parts})
|
||||||
|
|
||||||
|
# Add the current prompt
|
||||||
|
current_parts = []
|
||||||
|
if not files_added_to_payload and not active_cache_id and file_objects:
|
||||||
|
for f in file_objects:
|
||||||
|
current_parts.append({"file_data": {"file_uri": f.uri, "mime_type": f.mime_type}})
|
||||||
|
files_added_to_payload = True
|
||||||
|
|
||||||
generation_contents.append(prompt_text)
|
current_parts.append({"text": prompt_text})
|
||||||
|
api_contents.append({"role": "user", "parts": current_parts})
|
||||||
|
|
||||||
config = types.GenerateContentConfig(**config_kwargs)
|
config = types.GenerateContentConfig(**config_kwargs)
|
||||||
|
|
||||||
print("Generating response (this may take a moment for large outputs)...")
|
print("Generating response (this may take a moment for large outputs)...")
|
||||||
|
|
||||||
response_stream = client.models.generate_content_stream(
|
response_stream = client.models.generate_content_stream(
|
||||||
model=args.model,
|
model=args.model,
|
||||||
contents=generation_contents,
|
contents=api_contents,
|
||||||
config=config
|
config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
full_response_text = ""
|
||||||
|
|
||||||
if args.output:
|
if args.output:
|
||||||
with open(args.output, "w") as f:
|
with open(args.output, "w") as f:
|
||||||
for chunk in response_stream:
|
for chunk in response_stream:
|
||||||
if chunk.text:
|
if chunk.text:
|
||||||
f.write(chunk.text)
|
f.write(chunk.text)
|
||||||
f.flush()
|
f.flush()
|
||||||
|
full_response_text += chunk.text
|
||||||
print(f"\nDone! Raw output saved directly to {args.output}")
|
print(f"\nDone! Raw output saved directly to {args.output}")
|
||||||
else:
|
else:
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
for chunk in response_stream:
|
for chunk in response_stream:
|
||||||
if chunk.text:
|
if chunk.text:
|
||||||
print(chunk.text, end="", flush=True)
|
print(chunk.text, end="", flush=True)
|
||||||
|
full_response_text += chunk.text
|
||||||
print("\n" + "-" * 40)
|
print("\n" + "-" * 40)
|
||||||
|
|
||||||
|
# Append this turn to the local history and save
|
||||||
|
context_data["history"].append({"role": "user", "text": prompt_text})
|
||||||
|
context_data["history"].append({"role": "model", "text": full_response_text})
|
||||||
|
|
||||||
|
if args.context:
|
||||||
|
with open(args.context, "w") as f:
|
||||||
|
json.dump(context_data, f, indent=4)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
# TRANSIENT MODE CLEANUP
|
# TRANSIENT MODE CLEANUP
|
||||||
|
|||||||
Reference in New Issue
Block a user