diff --git a/gemini/gemini.py b/gemini/gemini.py index a57d807..98018cf 100644 --- a/gemini/gemini.py +++ b/gemini/gemini.py @@ -103,36 +103,51 @@ def main(): "and output raw data only. If not specific data format is suggested, you can answer with conversational text." ) + cache_too_small = False + file_objects = [] + if context_data.get("file_ids"): + # We always need the file objects just in case caching fails + file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]] + if not context_data.get("cache_id"): - print("Creating Context Cache on Google's servers...") + print("Attempting to create Context Cache on Google's servers...") - # Retrieve the file objects - file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]] - - # Create the server-side cache (Set to expire in 3600 seconds / 60 minutes) - cache = client.caches.create( - model=args.model, - config=types.CreateCachedContentConfig( - contents=file_objects, - system_instruction=system_instruction, - ttl="3600s" + try: + # Attempt to create the cache + cache = client.caches.create( + model=args.model, + config=types.CreateCachedContentConfig( + contents=file_objects, + system_instruction=system_instruction, + ttl="3600s" + ) ) - ) - context_data["cache_id"] = cache.name - with open(args.context, "w") as f: - json.dump(context_data, f, indent=4) - print(f"Context Cache created: {cache.name}") + context_data["cache_id"] = cache.name + with open(args.context, "w") as f: + json.dump(context_data, f, indent=4) + print(f"Context Cache created: {cache.name}") - else: + except Exception as e: + # Catch the specific size error and fall back + if "too small" in str(e).lower() or "1024" in str(e): + print("Notice: Files are too small for server-side caching (under 1024 tokens). Falling back to standard processing.") + cache_too_small = True + else: + raise e # Reraise if it's a different error (like authentication) + + elif not cache_too_small: print(f"Loading existing cache: {context_data['cache_id']}") print("Extending cache TTL by 60 minutes...") - client.caches.update( - name=context_data["cache_id"], - config=types.UpdateCachedContentConfig( - ttl="3600s" + try: + client.caches.update( + name=context_data["cache_id"], + config=types.UpdateCachedContentConfig( + ttl="3600s" + ) ) - ) + except Exception as e: + print(f"Warning: Failed to update cache TTL. It may have expired. {e}") # --------------------------------------------------------- # GENERATION LOGIC @@ -143,21 +158,25 @@ def main(): "temperature": 0.0 } - # The new SDK passes the cache ID directly into the generation config. - # If caching is used, the system_instruction must be in the cache, not here. - if context_data.get("cache_id"): + # Setup contents array + generation_contents = [] + + if context_data.get("cache_id") and not cache_too_small: + # If we successfully cached, we just pass the cache ID in the config config_kwargs["cached_content"] = context_data["cache_id"] else: + # If we didn't cache (or it was too small), pass the files and system instruction directly + generation_contents.extend(file_objects) config_kwargs["system_instruction"] = system_instruction + generation_contents.append(prompt_text) config = types.GenerateContentConfig(**config_kwargs) print("Generating response (this may take a moment for large outputs)...") - # We use stream generation so it writes immediately, avoiding memory bottlenecks response_stream = client.models.generate_content_stream( model=args.model, - contents=prompt_text, + contents=generation_contents, config=config ) @@ -166,7 +185,7 @@ def main(): for chunk in response_stream: if chunk.text: f.write(chunk.text) - f.flush() # Stream direct to the disk in real-time + f.flush() print(f"\nDone! Raw output saved directly to {args.output}") else: print("-" * 40)