From 2da7ea8aa5e323c373fc79f192b1fd4853037b2e Mon Sep 17 00:00:00 2001 From: athikha-faiz_infosys Date: Wed, 15 Oct 2025 16:12:52 +0530 Subject: [PATCH 1/7] enhancement #13508 --- simulated_annealing/README.md | 24 ++++ simulated_annealing/__init__.py | 10 ++ simulated_annealing/example.py | 33 +++++ simulated_annealing/gui.py | 157 +++++++++++++++++++++ simulated_annealing/simulated_annealing.py | 119 ++++++++++++++++ 5 files changed, 343 insertions(+) create mode 100644 simulated_annealing/README.md create mode 100644 simulated_annealing/__init__.py create mode 100644 simulated_annealing/example.py create mode 100644 simulated_annealing/gui.py create mode 100644 simulated_annealing/simulated_annealing.py diff --git a/simulated_annealing/README.md b/simulated_annealing/README.md new file mode 100644 index 000000000000..840feae3fa8c --- /dev/null +++ b/simulated_annealing/README.md @@ -0,0 +1,24 @@ +Simulated Annealing +=================== + +This package provides a simple Simulated Annealing optimizer and a Tkinter GUI to explore parameters and visualize optimization progress. + +Files +- `simulated_annealing.py` - core optimizer (class `SimulatedAnnealing`) +- `example.py` - example functions and CLI demo +- `gui.py` - Tkinter GUI with embedded matplotlib plot + +Quick start +----------- + +Run the GUI: + +```bash +python -m simulated_annealing.gui +``` + +Run the CLI example: + +```bash +python -m simulated_annealing.example +``` diff --git a/simulated_annealing/__init__.py b/simulated_annealing/__init__.py new file mode 100644 index 000000000000..bfaf1d331c90 --- /dev/null +++ b/simulated_annealing/__init__.py @@ -0,0 +1,10 @@ +"""Simulated Annealing package + +Exports: +- SimulatedAnnealing: core optimizer class +- example_functions: a small collection of test functions +""" +from .simulated_annealing import SimulatedAnnealing +from .example import example_functions + +__all__ = ["SimulatedAnnealing", "example_functions"] diff --git a/simulated_annealing/example.py b/simulated_annealing/example.py new file mode 100644 index 000000000000..c3fc8fed630a --- /dev/null +++ b/simulated_annealing/example.py @@ -0,0 +1,33 @@ +from typing import Callable, Dict, Sequence + + +def sphere(x: Sequence[float]) -> float: + return sum(v * v for v in x) + + +def rastrigin(x: Sequence[float]) -> float: + # Rastrigin function (common test function) + A = 10 + return A * len(x) + sum((v * v - A * __import__("math").cos(2 * __import__("math").pi * v)) for v in x) + + +example_functions: Dict[str, Callable[[Sequence[float]], float]] = { + "sphere": sphere, + "rastrigin": rastrigin, +} + + +def cli_example(): + # CLI demo minimizing 2D sphere + from .simulated_annealing import SimulatedAnnealing + func = sphere + initial = [5.0, -3.0] + bounds = [(-10, 10), (-10, 10)] + sa = SimulatedAnnealing(func, initial, bounds=bounds, temperature=50, cooling_rate=0.95, iterations_per_temp=200) + best, cost, history = sa.optimize() + print("Best:", best) + print("Cost:", cost) + + +if __name__ == "__main__": + cli_example() diff --git a/simulated_annealing/gui.py b/simulated_annealing/gui.py new file mode 100644 index 000000000000..9bc32598481c --- /dev/null +++ b/simulated_annealing/gui.py @@ -0,0 +1,157 @@ +import threading +import tkinter as tk +from tkinter import ttk, messagebox +from typing import Optional + +import matplotlib +matplotlib.use("TkAgg") +from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg +import matplotlib.pyplot as plt + +from .simulated_annealing import SimulatedAnnealing +from .example import example_functions + + +class SA_GUI(tk.Tk): + def __init__(self): + super().__init__() + self.title("Simulated Annealing Explorer") + self.geometry("800x600") + + # Left: controls + ctrl = ttk.Frame(self) + ctrl.pack(side=tk.LEFT, fill=tk.Y, padx=8, pady=8) + + ttk.Label(ctrl, text="Function:").pack(anchor=tk.W) + self.func_var = tk.StringVar(value="sphere") + func_menu = ttk.Combobox(ctrl, textvariable=self.func_var, values=list(example_functions.keys()), state="readonly") + func_menu.pack(fill=tk.X) + + ttk.Label(ctrl, text="Initial (comma-separated)").pack(anchor=tk.W, pady=(8, 0)) + self.init_entry = ttk.Entry(ctrl) + self.init_entry.insert(0, "5, -3") + self.init_entry.pack(fill=tk.X) + + ttk.Label(ctrl, text="Bounds (lo:hi comma-separated for each)").pack(anchor=tk.W, pady=(8, 0)) + self.bounds_entry = ttk.Entry(ctrl) + self.bounds_entry.insert(0, "-10:10, -10:10") + self.bounds_entry.pack(fill=tk.X) + + ttk.Label(ctrl, text="Temperature").pack(anchor=tk.W, pady=(8, 0)) + self.temp_entry = ttk.Entry(ctrl) + self.temp_entry.insert(0, "50") + self.temp_entry.pack(fill=tk.X) + + ttk.Label(ctrl, text="Cooling rate").pack(anchor=tk.W, pady=(8, 0)) + self.cool_entry = ttk.Entry(ctrl) + self.cool_entry.insert(0, "0.95") + self.cool_entry.pack(fill=tk.X) + + ttk.Label(ctrl, text="Iterations per temp").pack(anchor=tk.W, pady=(8, 0)) + self.iter_entry = ttk.Entry(ctrl) + self.iter_entry.insert(0, "200") + self.iter_entry.pack(fill=tk.X) + + self.run_btn = ttk.Button(ctrl, text="Run", command=self._on_run) + self.run_btn.pack(fill=tk.X, pady=(12, 0)) + + self.stop_flag = threading.Event() + self.stop_btn = ttk.Button(ctrl, text="Stop", command=self._on_stop, state=tk.DISABLED) + self.stop_btn.pack(fill=tk.X, pady=(6, 0)) + + # Right: plot + fig, self.ax = plt.subplots(figsize=(5, 4)) + self.fig = fig + self.canvas = FigureCanvasTkAgg(fig, master=self) + self.canvas.get_tk_widget().pack(side=tk.RIGHT, fill=tk.BOTH, expand=1) + + self._plot_line, = self.ax.plot([], [], label="best_cost") + self.ax.set_xlabel("Iterations") + self.ax.set_ylabel("Best cost") + self.ax.grid(True) + + def _parse_initial(self) -> list: + raw = self.init_entry.get().strip() + parts = [p.strip() for p in raw.split(",") if p.strip()] + return [float(p) for p in parts] + + def _parse_bounds(self, dim: int): + raw = self.bounds_entry.get().strip() + parts = [p.strip() for p in raw.split(",") if p.strip()] + bounds = [] + for p in parts: + if ":" in p: + lo, hi = p.split(":", 1) + bounds.append((float(lo), float(hi))) + else: + # single number -> symmetric + val = float(p) + bounds.append((-abs(val), abs(val))) + # if fewer provided, extend with wide bounds + while len(bounds) < dim: + bounds.append((-1e6, 1e6)) + return bounds[:dim] + + def _on_run(self): + try: + initial = self._parse_initial() + except Exception as e: + messagebox.showerror("Input error", f"Invalid initial: {e}") + return + + func_name = self.func_var.get() + func = example_functions.get(func_name) + if func is None: + messagebox.showerror("Input error", "Unknown function") + return + + try: + temp = float(self.temp_entry.get()) + cooling = float(self.cool_entry.get()) + iterations = int(self.iter_entry.get()) + except Exception as e: + messagebox.showerror("Input error", f"Invalid numeric param: {e}") + return + + bounds = self._parse_bounds(len(initial)) + + self.run_btn.config(state=tk.DISABLED) + self.stop_btn.config(state=tk.NORMAL) + self.stop_flag.clear() + + def worker(): + sa = SimulatedAnnealing(func, initial, bounds=bounds, temperature=temp, cooling_rate=cooling, iterations_per_temp=iterations) + best, cost, history = sa.optimize() + # update plot on main thread + self.after(0, lambda: self._on_complete(best, cost, history)) + + t = threading.Thread(target=worker, daemon=True) + t.start() + + def _on_stop(self): + # currently we don't have a cooperative stop in the algorithm; inform user + messagebox.showinfo("Stop", "Stop requested, but immediate stop is not implemented. The run will finish current loop.") + + def _on_complete(self, best, cost, history): + x = list(range(len(history.get("best_costs", [])))) + y = history.get("best_costs", []) + self.ax.clear() + self.ax.plot(x, y, label="best_cost") + self.ax.set_xlabel("Iterations") + self.ax.set_ylabel("Best cost") + self.ax.grid(True) + self.ax.legend() + self.canvas.draw() + + messagebox.showinfo("Done", f"Best cost: {cost:.6g}\nBest solution: {best}") + self.run_btn.config(state=tk.NORMAL) + self.stop_btn.config(state=tk.DISABLED) + + +def main(): + app = SA_GUI() + app.mainloop() + + +if __name__ == "__main__": + main() diff --git a/simulated_annealing/simulated_annealing.py b/simulated_annealing/simulated_annealing.py new file mode 100644 index 000000000000..acc966df78ee --- /dev/null +++ b/simulated_annealing/simulated_annealing.py @@ -0,0 +1,119 @@ +import math +import random +from typing import Callable, Sequence, Tuple, List, Optional + + +class SimulatedAnnealing: + """Generic Simulated Annealing optimizer for continuous domains. + + Usage: + sa = SimulatedAnnealing(func, initial_solution, bounds=..., **params) + best, best_cost, history = sa.optimize() + """ + + def __init__( + self, + func: Callable[[Sequence[float]], float], + initial_solution: Sequence[float], + bounds: Optional[Sequence[Tuple[float, float]]] = None, + temperature: float = 100.0, + cooling_rate: float = 0.99, + min_temperature: float = 1e-3, + iterations_per_temp: int = 100, + neighbor_scale: float = 0.1, + seed: Optional[int] = None, + ): + self.func = func + self.current = list(initial_solution) + self.dim = len(initial_solution) + self.bounds = bounds + self.temperature = float(temperature) + self.initial_temperature = float(temperature) + self.cooling_rate = float(cooling_rate) + self.min_temperature = float(min_temperature) + self.iterations_per_temp = int(iterations_per_temp) + self.neighbor_scale = float(neighbor_scale) + if seed is not None: + random.seed(seed) + + def _clip(self, x: float, i: int) -> float: + if not self.bounds: + return x + lo, hi = self.bounds[i] + return max(lo, min(hi, x)) + + def _neighbor(self, solution: Sequence[float]) -> List[float]: + # Gaussian perturbation scaled by neighbor_scale and variable range + new = [] + for i, v in enumerate(solution): + scale = self.neighbor_scale + if self.bounds: + lo, hi = self.bounds[i] + rng = (hi - lo) if hi > lo else 1.0 + scale = self.neighbor_scale * rng + candidate = v + random.gauss(0, scale) + candidate = self._clip(candidate, i) + new.append(candidate) + return new + + def _accept(self, delta: float, temp: float) -> bool: + if delta < 0: + return True + try: + prob = math.exp(-delta / temp) + except OverflowError: + prob = 0.0 + return random.random() < prob + + def optimize(self, max_steps: Optional[int] = None) -> Tuple[List[float], float, dict]: + """Run optimization and return best solution, cost, and history. + + history dict contains: temps, best_costs, current_costs + """ + temp = self.temperature + current = list(self.current) + current_cost = float(self.func(current)) + best = list(current) + best_cost = current_cost + + history = {"temps": [], "best_costs": [], "current_costs": []} + + steps = 0 + while temp > self.min_temperature: + for _ in range(self.iterations_per_temp): + candidate = self._neighbor(current) + candidate_cost = float(self.func(candidate)) + delta = candidate_cost - current_cost + if self._accept(delta, temp): + current = candidate + current_cost = candidate_cost + if current_cost < best_cost: + best = list(current) + best_cost = current_cost + + history["temps"].append(temp) + history["best_costs"].append(best_cost) + history["current_costs"].append(current_cost) + + steps += 1 + if max_steps is not None and steps >= max_steps: + self.current = current + return best, best_cost, history + + # Cool down + temp *= self.cooling_rate + + self.current = current + return best, best_cost, history + + +def _test_quadratic(): + # Simple test: minimize f(x) = (x-3)^2 + func = lambda x: (x[0] - 3) ** 2 + sa = SimulatedAnnealing(func, [0.0], bounds=[(-10, 10)], temperature=10, iterations_per_temp=50) + best, cost, hist = sa.optimize() + print("best:", best, "cost:", cost) + + +if __name__ == "__main__": + _test_quadratic() From fe2d82190e5f331f23213da8a65b4318995ec565 Mon Sep 17 00:00:00 2001 From: athikha-faiz_infosys Date: Wed, 15 Oct 2025 16:20:14 +0530 Subject: [PATCH 2/7] Add simulated annealing package with GUI, TSP helper, and unit tests --- simulated_annealing/example.py | 23 ++++++++++ simulated_annealing/gui.py | 33 ++++++++++++-- simulated_annealing/simulated_annealing.py | 18 +++++++- simulated_annealing/tsp.py | 53 ++++++++++++++++++++++ tests/test_simulated_annealing.py | 28 ++++++++++++ tests/test_tsp.py | 20 ++++++++ 6 files changed, 170 insertions(+), 5 deletions(-) create mode 100644 simulated_annealing/tsp.py create mode 100644 tests/test_simulated_annealing.py create mode 100644 tests/test_tsp.py diff --git a/simulated_annealing/example.py b/simulated_annealing/example.py index c3fc8fed630a..4adb6de93943 100644 --- a/simulated_annealing/example.py +++ b/simulated_annealing/example.py @@ -29,5 +29,28 @@ def cli_example(): print("Cost:", cost) +def tsp_example(): + # Small TSP demo + from .simulated_annealing import SimulatedAnnealing + from .tsp import make_tsp_cost, random_tour, vector_to_tour + + coords = [(0, 0), (1, 5), (5, 4), (6, 1), (3, -2)] + n = len(coords) + init_tour = random_tour(n) + # represent tour as vector by using tour indices as values (so ranking recovers order) + initial = [float(i) for i in init_tour] + cost_fn = make_tsp_cost(coords) + + sa = SimulatedAnnealing(cost_fn, initial, temperature=100.0, cooling_rate=0.995, iterations_per_temp=500) + best, cost, history = sa.optimize() + best_tour = vector_to_tour(best) + print("Best tour:", best_tour) + print("Cost:", cost) + + if __name__ == "__main__": + # Run CLI examples + print("Running continuous example...") cli_example() + print("Running TSP example...") + tsp_example() diff --git a/simulated_annealing/gui.py b/simulated_annealing/gui.py index 9bc32598481c..a0ce4f1454f1 100644 --- a/simulated_annealing/gui.py +++ b/simulated_annealing/gui.py @@ -121,16 +121,24 @@ def _on_run(self): def worker(): sa = SimulatedAnnealing(func, initial, bounds=bounds, temperature=temp, cooling_rate=cooling, iterations_per_temp=iterations) - best, cost, history = sa.optimize() - # update plot on main thread + + def progress_cb(step, best_cost, current_cost): + # schedule a plot update on the main thread + self.after(0, lambda: self._update_plot_partial(step, best_cost)) + + best, cost, history = sa.optimize(stop_event=self.stop_flag, progress_callback=progress_cb) + # update final plot on main thread self.after(0, lambda: self._on_complete(best, cost, history)) t = threading.Thread(target=worker, daemon=True) t.start() def _on_stop(self): - # currently we don't have a cooperative stop in the algorithm; inform user - messagebox.showinfo("Stop", "Stop requested, but immediate stop is not implemented. The run will finish current loop.") + # set stop flag; optimizer will stop cooperatively + self.stop_flag.set() + self.stop_btn.config(state=tk.DISABLED) + self.run_btn.config(state=tk.NORMAL) + messagebox.showinfo("Stop", "Stop requested; optimizer will stop shortly.") def _on_complete(self, best, cost, history): x = list(range(len(history.get("best_costs", [])))) @@ -147,6 +155,23 @@ def _on_complete(self, best, cost, history): self.run_btn.config(state=tk.NORMAL) self.stop_btn.config(state=tk.DISABLED) + def _update_plot_partial(self, step: int, best_cost: float): + # Append a new point to plot (x=step, y=best_cost) + # We'll redraw full plot for simplicity + line_x = list(range(len(self.ax.lines[0].get_xdata()) + 1)) if self.ax.lines else [step] + if self.ax.lines: + ydata = list(self.ax.lines[0].get_ydata()) + ydata.append(best_cost) + else: + ydata = [best_cost] + self.ax.clear() + self.ax.plot(line_x, ydata, label="best_cost") + self.ax.set_xlabel("Iterations") + self.ax.set_ylabel("Best cost") + self.ax.grid(True) + self.ax.legend() + self.canvas.draw() + def main(): app = SA_GUI() diff --git a/simulated_annealing/simulated_annealing.py b/simulated_annealing/simulated_annealing.py index acc966df78ee..1ad354dde574 100644 --- a/simulated_annealing/simulated_annealing.py +++ b/simulated_annealing/simulated_annealing.py @@ -65,9 +65,13 @@ def _accept(self, delta: float, temp: float) -> bool: prob = 0.0 return random.random() < prob - def optimize(self, max_steps: Optional[int] = None) -> Tuple[List[float], float, dict]: + def optimize(self, max_steps: Optional[int] = None, stop_event: Optional[object] = None, progress_callback: Optional[Callable[[int, float, float], None]] = None) -> Tuple[List[float], float, dict]: """Run optimization and return best solution, cost, and history. + New optional args: + - stop_event: a threading.Event-like object. If set, optimization stops early. + - progress_callback: callable(step, best_cost, current_cost) called periodically. + history dict contains: temps, best_costs, current_costs """ temp = self.temperature @@ -81,6 +85,11 @@ def optimize(self, max_steps: Optional[int] = None) -> Tuple[List[float], float, steps = 0 while temp > self.min_temperature: for _ in range(self.iterations_per_temp): + # Check stop event + if stop_event is not None and getattr(stop_event, "is_set", lambda: False)(): + self.current = current + return best, best_cost, history + candidate = self._neighbor(current) candidate_cost = float(self.func(candidate)) delta = candidate_cost - current_cost @@ -96,6 +105,13 @@ def optimize(self, max_steps: Optional[int] = None) -> Tuple[List[float], float, history["current_costs"].append(current_cost) steps += 1 + if progress_callback is not None and steps % max(1, self.iterations_per_temp // 10) == 0: + try: + progress_callback(steps, best_cost, current_cost) + except Exception: + # Don't let callback errors stop optimization + pass + if max_steps is not None and steps >= max_steps: self.current = current return best, best_cost, history diff --git a/simulated_annealing/tsp.py b/simulated_annealing/tsp.py new file mode 100644 index 000000000000..703f7af66517 --- /dev/null +++ b/simulated_annealing/tsp.py @@ -0,0 +1,53 @@ +import math +import random +from typing import List, Sequence, Tuple + + +def euclidean_distance(a: Sequence[float], b: Sequence[float]) -> float: + return math.hypot(a[0] - b[0], a[1] - b[1]) + + +def total_distance(tour: Sequence[int], coords: Sequence[Tuple[float, float]]) -> float: + d = 0.0 + n = len(tour) + for i in range(n): + a = coords[tour[i]] + b = coords[tour[(i + 1) % n]] + d += euclidean_distance(a, b) + return d + + +def random_tour(n: int) -> List[int]: + tour = list(range(n)) + random.shuffle(tour) + return tour + + +def neighbor_swap(tour: Sequence[int]) -> List[int]: + # swap two indices + n = len(tour) + i, j = random.sample(range(n), 2) + new = list(tour) + new[i], new[j] = new[j], new[i] + return new + + +def tour_to_vector(tour: Sequence[int]) -> List[float]: + # Convert permutation to a float vector for generic optimizer compatibility + return [float(i) for i in tour] + + +def vector_to_tour(vec: Sequence[float]) -> List[int]: + # Convert vector of floats back to a tour by ranking + pairs = list(enumerate(vec)) + pairs.sort(key=lambda p: p[1]) + return [int(p[0]) for p in pairs] + + +def make_tsp_cost(coords: Sequence[Tuple[float, float]]): + def cost_from_vector(vec: Sequence[float]) -> float: + # Convert vector to tour and compute total distance + tour = vector_to_tour(vec) + return total_distance(tour, coords) + + return cost_from_vector diff --git a/tests/test_simulated_annealing.py b/tests/test_simulated_annealing.py new file mode 100644 index 000000000000..76a912df1a07 --- /dev/null +++ b/tests/test_simulated_annealing.py @@ -0,0 +1,28 @@ +import unittest +from simulated_annealing.simulated_annealing import SimulatedAnnealing + + +def sphere(x): + return sum(v * v for v in x) + + +class TestSimulatedAnnealing(unittest.TestCase): + def test_minimize_sphere_1d(self): + sa = SimulatedAnnealing(sphere, [5.0], bounds=[(-10, 10)], temperature=10, cooling_rate=0.9, iterations_per_temp=50) + best, cost, hist = sa.optimize() + # Best should be near 0 with tiny cost + self.assertLess(cost, 1e-2) + + def test_stop_event(self): + import threading + stop = threading.Event() + sa = SimulatedAnnealing(sphere, [5.0], bounds=[(-10, 10)], temperature=10, cooling_rate=0.9, iterations_per_temp=1000) + # request stop immediately + stop.set() + best, cost, hist = sa.optimize(stop_event=stop) + # Should return without error + self.assertIsNotNone(best) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_tsp.py b/tests/test_tsp.py new file mode 100644 index 000000000000..a9ac0b9f0dcb --- /dev/null +++ b/tests/test_tsp.py @@ -0,0 +1,20 @@ +import unittest +from simulated_annealing.tsp import total_distance, random_tour, euclidean_distance + + +class TestTSP(unittest.TestCase): + def test_distance_symmetric(self): + a = (0, 0) + b = (3, 4) + d = euclidean_distance(a, b) + self.assertAlmostEqual(d, 5.0) + + def test_total_distance_cycle(self): + coords = [(0, 0), (0, 1), (1, 1), (1, 0)] + tour = [0, 1, 2, 3] + d = total_distance(tour, coords) + self.assertGreater(d, 0) + + +if __name__ == '__main__': + unittest.main() From e4422cd9e268d19fe81809bc8e6a27c1a0ebedac Mon Sep 17 00:00:00 2001 From: athikha-faiz_infosys Date: Thu, 16 Oct 2025 13:28:19 +0530 Subject: [PATCH 3/7] chore: move simple data-structures into data_structures.simple and remove ds_extra --- data_structures/__init__.py | 9 ++ data_structures/simple/__init__.py | 26 ++++++ data_structures/simple/bst.py | 46 ++++++++++ data_structures/simple/graph.py | 41 +++++++++ data_structures/simple/linked_list.py | 56 ++++++++++++ data_structures/simple/min_heap.py | 16 ++++ data_structures/simple/queue.py | 16 ++++ data_structures/simple/run_ds_tests.py | 119 +++++++++++++++++++++++++ data_structures/simple/stack.py | 34 +++++++ data_structures/simple/trie.py | 34 +++++++ data_structures/simple/union_find.py | 30 +++++++ 11 files changed, 427 insertions(+) create mode 100644 data_structures/simple/__init__.py create mode 100644 data_structures/simple/bst.py create mode 100644 data_structures/simple/graph.py create mode 100644 data_structures/simple/linked_list.py create mode 100644 data_structures/simple/min_heap.py create mode 100644 data_structures/simple/queue.py create mode 100644 data_structures/simple/run_ds_tests.py create mode 100644 data_structures/simple/stack.py create mode 100644 data_structures/simple/trie.py create mode 100644 data_structures/simple/union_find.py diff --git a/data_structures/__init__.py b/data_structures/__init__.py index e69de29bb2d1..ff5f4a759edc 100644 --- a/data_structures/__init__.py +++ b/data_structures/__init__.py @@ -0,0 +1,9 @@ +"""Top-level data_structures package. + +This file exposes a small `simple` collection of educational data structures +under `data_structures.simple` and provides a few convenience imports. +""" + +from . import simple + +__all__ = ["simple"] diff --git a/data_structures/simple/__init__.py b/data_structures/simple/__init__.py new file mode 100644 index 000000000000..b50ec0f5931e --- /dev/null +++ b/data_structures/simple/__init__.py @@ -0,0 +1,26 @@ +"""Simple implementations of common data structures. + +This subpackage contains small, educational implementations copied from the +`ds_extra` package so they can be accessed via `data_structures.simple`. +""" + +from .linked_list import LinkedList +from .stack import ListStack, LinkedStack +from .queue import Queue +from .bst import BinarySearchTree +from .graph import Graph +from .trie import Trie +from .union_find import UnionFind +from .min_heap import MinHeap + +__all__ = [ + "LinkedList", + "ListStack", + "LinkedStack", + "Queue", + "BinarySearchTree", + "Graph", + "Trie", + "UnionFind", + "MinHeap", +] diff --git a/data_structures/simple/bst.py b/data_structures/simple/bst.py new file mode 100644 index 000000000000..97da25e54505 --- /dev/null +++ b/data_structures/simple/bst.py @@ -0,0 +1,46 @@ +from typing import Any, Optional, List + + +class BSTNode: + def __init__(self, key: Any): + self.key = key + self.left: Optional[BSTNode] = None + self.right: Optional[BSTNode] = None + + +class BinarySearchTree: + def __init__(self): + self.root: Optional[BSTNode] = None + + def insert(self, key: Any) -> None: + def _insert(node: Optional[BSTNode], key: Any) -> BSTNode: + if not node: + return BSTNode(key) + if key < node.key: + node.left = _insert(node.left, key) + else: + node.right = _insert(node.right, key) + return node + + self.root = _insert(self.root, key) + + def search(self, key: Any) -> bool: + node = self.root + while node: + if key == node.key: + return True + node = node.left if key < node.key else node.right + return False + + def inorder(self) -> List[Any]: + res: List[Any] = [] + + def _in(node: Optional[BSTNode]) -> None: + if not node: + return + _in(node.left) + res.append(node.key) + _in(node.right) + + _in(self.root) + return res diff --git a/data_structures/simple/graph.py b/data_structures/simple/graph.py new file mode 100644 index 000000000000..6eccebc7bb9e --- /dev/null +++ b/data_structures/simple/graph.py @@ -0,0 +1,41 @@ +from collections import defaultdict, deque +from typing import Any, Dict, List + + +class Graph: + def __init__(self, directed: bool = False): + self.adj: Dict[Any, List[Any]] = defaultdict(list) + self.directed = directed + + def add_edge(self, u: Any, v: Any) -> None: + self.adj[u].append(v) + if not self.directed: + self.adj[v].append(u) + + def bfs(self, start: Any) -> List[Any]: + visited = set() + order = [] + q = deque([start]) + visited.add(start) + while q: + u = q.popleft() + order.append(u) + for v in self.adj[u]: + if v not in visited: + visited.add(v) + q.append(v) + return order + + def dfs(self, start: Any) -> List[Any]: + visited = set() + order = [] + + def _dfs(u: Any) -> None: + visited.add(u) + order.append(u) + for v in self.adj[u]: + if v not in visited: + _dfs(v) + + _dfs(start) + return order diff --git a/data_structures/simple/linked_list.py b/data_structures/simple/linked_list.py new file mode 100644 index 000000000000..7558c7421bd4 --- /dev/null +++ b/data_structures/simple/linked_list.py @@ -0,0 +1,56 @@ +from typing import Any, Optional + + +class Node: + def __init__(self, value: Any): + self.value = value + self.next: Optional[Node] = None + + +class LinkedList: + def __init__(self): + self.head: Optional[Node] = None + + def append(self, value: Any) -> None: + node = Node(value) + if not self.head: + self.head = node + return + cur = self.head + while cur.next: + cur = cur.next + cur.next = node + + def prepend(self, value: Any) -> None: + node = Node(value) + node.next = self.head + self.head = node + + def find(self, value: Any) -> Optional[Node]: + cur = self.head + while cur: + if cur.value == value: + return cur + cur = cur.next + return None + + def delete(self, value: Any) -> bool: + cur = self.head + prev = None + while cur: + if cur.value == value: + if prev: + prev.next = cur.next + else: + self.head = cur.next + return True + prev, cur = cur, cur.next + return False + + def to_list(self) -> list: + out = [] + cur = self.head + while cur: + out.append(cur.value) + cur = cur.next + return out diff --git a/data_structures/simple/min_heap.py b/data_structures/simple/min_heap.py new file mode 100644 index 000000000000..a55f18ed8775 --- /dev/null +++ b/data_structures/simple/min_heap.py @@ -0,0 +1,16 @@ +import heapq +from typing import Any, List, Optional + + +class MinHeap: + def __init__(self): + self._data: List[Any] = [] + + def push(self, v: Any) -> None: + heapq.heappush(self._data, v) + + def pop(self) -> Any: + return heapq.heappop(self._data) + + def peek(self) -> Optional[Any]: + return self._data[0] if self._data else None diff --git a/data_structures/simple/queue.py b/data_structures/simple/queue.py new file mode 100644 index 000000000000..446fa0cf3a9f --- /dev/null +++ b/data_structures/simple/queue.py @@ -0,0 +1,16 @@ +from collections import deque +from typing import Any, Optional + + +class Queue: + def __init__(self): + self._dq = deque() + + def enqueue(self, v: Any) -> None: + self._dq.append(v) + + def dequeue(self) -> Any: + return self._dq.popleft() + + def peek(self) -> Optional[Any]: + return self._dq[0] if self._dq else None diff --git a/data_structures/simple/run_ds_tests.py b/data_structures/simple/run_ds_tests.py new file mode 100644 index 000000000000..67aa19b36b47 --- /dev/null +++ b/data_structures/simple/run_ds_tests.py @@ -0,0 +1,119 @@ +"""Smoke test runner for the `data_structures.simple` package.""" + +import sys + +from data_structures.simple.linked_list import LinkedList +from data_structures.simple.stack import ListStack, LinkedStack +from data_structures.simple.queue import Queue +from data_structures.simple.bst import BinarySearchTree +from data_structures.simple.graph import Graph +from data_structures.simple.trie import Trie +from data_structures.simple.union_find import UnionFind +from data_structures.simple.min_heap import MinHeap + + +def run() -> None: + failures = [] + + try: + ll = LinkedList() + ll.append(1) + ll.append(2) + ll.prepend(0) + assert ll.to_list() == [0, 1, 2] + assert ll.find(1) is not None + assert ll.delete(1) is True + assert ll.to_list() == [0, 2] + print("LinkedList: OK") + except Exception as e: + failures.append(('LinkedList', e)) + + try: + s = ListStack() + s.push(1) + s.push(2) + assert s.peek() == 2 + assert s.pop() == 2 + + ls = LinkedStack() + ls.push('a') + ls.push('b') + assert ls.peek() == 'b' + assert ls.pop() == 'b' + print("Stacks: OK") + except Exception as e: + failures.append(('Stacks', e)) + + try: + q = Queue() + q.enqueue(1) + q.enqueue(2) + assert q.peek() == 1 + assert q.dequeue() == 1 + print("Queue: OK") + except Exception as e: + failures.append(('Queue', e)) + + try: + bst = BinarySearchTree() + for v in [3, 1, 4, 2]: + bst.insert(v) + assert bst.search(2) + assert bst.inorder() == [1, 2, 3, 4] + print("BST: OK") + except Exception as e: + failures.append(('BST', e)) + + try: + g = Graph() + g.add_edge(1, 2) + g.add_edge(1, 3) + order_bfs = g.bfs(1) + assert 1 in order_bfs + order_dfs = g.dfs(1) + assert 1 in order_dfs + print("Graph: OK") + except Exception as e: + failures.append(('Graph', e)) + + try: + t = Trie() + t.insert('hello') + assert t.search('hello') + assert t.starts_with('hel') + print("Trie: OK") + except Exception as e: + failures.append(('Trie', e)) + + try: + uf = UnionFind() + for x in [1, 2, 3]: + uf.make_set(x) + uf.union(1, 2) + assert uf.find(1) == uf.find(2) + print("UnionFind: OK") + except Exception as e: + failures.append(('UnionFind', e)) + + try: + h = MinHeap() + h.push(5) + h.push(1) + h.push(3) + assert h.peek() == 1 + assert h.pop() == 1 + print("MinHeap: OK") + except Exception as e: + failures.append(('MinHeap', e)) + + if failures: + print('\nFAILURES:') + for name, exc in failures: + print(f"- {name}: {exc}") + sys.exit(2) + else: + print('\nAll data_structures.simple smoke tests passed.') + + +if __name__ == '__main__': + run() diff --git a/data_structures/simple/stack.py b/data_structures/simple/stack.py new file mode 100644 index 000000000000..cf48cfef49ae --- /dev/null +++ b/data_structures/simple/stack.py @@ -0,0 +1,34 @@ +from typing import Any, Optional +from .linked_list import LinkedList + + +class ListStack: + def __init__(self): + self._data = [] + + def push(self, v: Any) -> None: + self._data.append(v) + + def pop(self) -> Any: + return self._data.pop() + + def peek(self) -> Optional[Any]: + return self._data[-1] if self._data else None + + +class LinkedStack: + def __init__(self): + self._list = LinkedList() + + def push(self, v: Any) -> None: + self._list.prepend(v) + + def pop(self) -> Any: + if not self._list.head: + raise IndexError("pop from empty stack") + val = self._list.head.value + self._list.head = self._list.head.next + return val + + def peek(self) -> Optional[Any]: + return self._list.head.value if self._list.head else None diff --git a/data_structures/simple/trie.py b/data_structures/simple/trie.py new file mode 100644 index 000000000000..5bcb849ae0e7 --- /dev/null +++ b/data_structures/simple/trie.py @@ -0,0 +1,34 @@ +from typing import Dict, Any + + +class TrieNode: + def __init__(self): + self.children: Dict[str, TrieNode] = {} + self.end: bool = False + + +class Trie: + def __init__(self): + self.root = TrieNode() + + def insert(self, word: str) -> None: + node = self.root + for ch in word: + node = node.children.setdefault(ch, TrieNode()) + node.end = True + + def search(self, word: str) -> bool: + node = self.root + for ch in word: + node = node.children.get(ch) + if node is None: + return False + return node.end + + def starts_with(self, prefix: str) -> bool: + node = self.root + for ch in prefix: + node = node.children.get(ch) + if node is None: + return False + return True diff --git a/data_structures/simple/union_find.py b/data_structures/simple/union_find.py new file mode 100644 index 000000000000..61ae9e4559f8 --- /dev/null +++ b/data_structures/simple/union_find.py @@ -0,0 +1,30 @@ +from typing import Dict, Any + + +class UnionFind: + def __init__(self): + self.parent: Dict[Any, Any] = {} + self.rank: Dict[Any, int] = {} + + def make_set(self, x: Any) -> None: + if x not in self.parent: + self.parent[x] = x + self.rank[x] = 0 + + def find(self, x: Any) -> Any: + if self.parent[x] != x: + self.parent[x] = self.find(self.parent[x]) + return self.parent[x] + + def union(self, x: Any, y: Any) -> None: + xroot = self.find(x) + yroot = self.find(y) + if xroot == yroot: + return + if self.rank[xroot] < self.rank[yroot]: + self.parent[xroot] = yroot + elif self.rank[xroot] > self.rank[yroot]: + self.parent[yroot] = xroot + else: + self.parent[yroot] = xroot + self.rank[xroot] += 1 From d005a615f245e525edef83c315dee5be12d1d497 Mon Sep 17 00:00:00 2001 From: athikha-faiz_infosys Date: Thu, 16 Oct 2025 13:32:06 +0530 Subject: [PATCH 4/7] adding some programs for data structures using python From b52536b62a6c0dd4962463392eb18556309bb9ef Mon Sep 17 00:00:00 2001 From: athikha-faiz_infosys Date: Fri, 17 Oct 2025 18:12:59 +0530 Subject: [PATCH 5/7] Fix doctest imports for stacks.balanced_parentheses and add doctest runner script; minor improvements --- .../binary_search_tree_recursive.py | 641 ------------------ .../stacks/balanced_parentheses.py | 18 +- scripts/run_doctests.py | 98 +++ 3 files changed, 115 insertions(+), 642 deletions(-) delete mode 100644 data_structures/binary_tree/binary_search_tree_recursive.py create mode 100644 scripts/run_doctests.py diff --git a/data_structures/binary_tree/binary_search_tree_recursive.py b/data_structures/binary_tree/binary_search_tree_recursive.py deleted file mode 100644 index d94ac5253360..000000000000 --- a/data_structures/binary_tree/binary_search_tree_recursive.py +++ /dev/null @@ -1,641 +0,0 @@ -""" -This is a python3 implementation of binary search tree using recursion - -To run tests: -python -m unittest binary_search_tree_recursive.py - -To run an example: -python binary_search_tree_recursive.py -""" - -from __future__ import annotations - -import unittest -from collections.abc import Iterator - -import pytest - - -class Node: - def __init__(self, label: int, parent: Node | None) -> None: - self.label = label - self.parent = parent - self.left: Node | None = None - self.right: Node | None = None - - -class BinarySearchTree: - def __init__(self) -> None: - self.root: Node | None = None - - def empty(self) -> None: - """ - Empties the tree - - >>> t = BinarySearchTree() - >>> assert t.root is None - >>> t.put(8) - >>> assert t.root is not None - """ - self.root = None - - def is_empty(self) -> bool: - """ - Checks if the tree is empty - - >>> t = BinarySearchTree() - >>> t.is_empty() - True - >>> t.put(8) - >>> t.is_empty() - False - """ - return self.root is None - - def put(self, label: int) -> None: - """ - Put a new node in the tree - - >>> t = BinarySearchTree() - >>> t.put(8) - >>> assert t.root.parent is None - >>> assert t.root.label == 8 - - >>> t.put(10) - >>> assert t.root.right.parent == t.root - >>> assert t.root.right.label == 10 - - >>> t.put(3) - >>> assert t.root.left.parent == t.root - >>> assert t.root.left.label == 3 - """ - self.root = self._put(self.root, label) - - def _put(self, node: Node | None, label: int, parent: Node | None = None) -> Node: - if node is None: - node = Node(label, parent) - elif label < node.label: - node.left = self._put(node.left, label, node) - elif label > node.label: - node.right = self._put(node.right, label, node) - else: - msg = f"Node with label {label} already exists" - raise ValueError(msg) - - return node - - def search(self, label: int) -> Node: - """ - Searches a node in the tree - - >>> t = BinarySearchTree() - >>> t.put(8) - >>> t.put(10) - >>> node = t.search(8) - >>> assert node.label == 8 - - >>> node = t.search(3) - Traceback (most recent call last): - ... - ValueError: Node with label 3 does not exist - """ - return self._search(self.root, label) - - def _search(self, node: Node | None, label: int) -> Node: - if node is None: - msg = f"Node with label {label} does not exist" - raise ValueError(msg) - elif label < node.label: - node = self._search(node.left, label) - elif label > node.label: - node = self._search(node.right, label) - - return node - - def remove(self, label: int) -> None: - """ - Removes a node in the tree - - >>> t = BinarySearchTree() - >>> t.put(8) - >>> t.put(10) - >>> t.remove(8) - >>> assert t.root.label == 10 - - >>> t.remove(3) - Traceback (most recent call last): - ... - ValueError: Node with label 3 does not exist - """ - node = self.search(label) - if node.right and node.left: - lowest_node = self._get_lowest_node(node.right) - lowest_node.left = node.left - lowest_node.right = node.right - node.left.parent = lowest_node - if node.right: - node.right.parent = lowest_node - self._reassign_nodes(node, lowest_node) - elif not node.right and node.left: - self._reassign_nodes(node, node.left) - elif node.right and not node.left: - self._reassign_nodes(node, node.right) - else: - self._reassign_nodes(node, None) - - def _reassign_nodes(self, node: Node, new_children: Node | None) -> None: - if new_children: - new_children.parent = node.parent - - if node.parent: - if node.parent.right == node: - node.parent.right = new_children - else: - node.parent.left = new_children - else: - self.root = new_children - - def _get_lowest_node(self, node: Node) -> Node: - if node.left: - lowest_node = self._get_lowest_node(node.left) - else: - lowest_node = node - self._reassign_nodes(node, node.right) - - return lowest_node - - def exists(self, label: int) -> bool: - """ - Checks if a node exists in the tree - - >>> t = BinarySearchTree() - >>> t.put(8) - >>> t.put(10) - >>> t.exists(8) - True - - >>> t.exists(3) - False - """ - try: - self.search(label) - return True - except ValueError: - return False - - def get_max_label(self) -> int: - """ - Gets the max label inserted in the tree - - >>> t = BinarySearchTree() - >>> t.get_max_label() - Traceback (most recent call last): - ... - ValueError: Binary search tree is empty - - >>> t.put(8) - >>> t.put(10) - >>> t.get_max_label() - 10 - """ - if self.root is None: - raise ValueError("Binary search tree is empty") - - node = self.root - while node.right is not None: - node = node.right - - return node.label - - def get_min_label(self) -> int: - """ - Gets the min label inserted in the tree - - >>> t = BinarySearchTree() - >>> t.get_min_label() - Traceback (most recent call last): - ... - ValueError: Binary search tree is empty - - >>> t.put(8) - >>> t.put(10) - >>> t.get_min_label() - 8 - """ - if self.root is None: - raise ValueError("Binary search tree is empty") - - node = self.root - while node.left is not None: - node = node.left - - return node.label - - def inorder_traversal(self) -> Iterator[Node]: - """ - Return the inorder traversal of the tree - - >>> t = BinarySearchTree() - >>> [i.label for i in t.inorder_traversal()] - [] - - >>> t.put(8) - >>> t.put(10) - >>> t.put(9) - >>> [i.label for i in t.inorder_traversal()] - [8, 9, 10] - """ - return self._inorder_traversal(self.root) - - def _inorder_traversal(self, node: Node | None) -> Iterator[Node]: - if node is not None: - yield from self._inorder_traversal(node.left) - yield node - yield from self._inorder_traversal(node.right) - - def preorder_traversal(self) -> Iterator[Node]: - """ - Return the preorder traversal of the tree - - >>> t = BinarySearchTree() - >>> [i.label for i in t.preorder_traversal()] - [] - - >>> t.put(8) - >>> t.put(10) - >>> t.put(9) - >>> [i.label for i in t.preorder_traversal()] - [8, 10, 9] - """ - return self._preorder_traversal(self.root) - - def _preorder_traversal(self, node: Node | None) -> Iterator[Node]: - if node is not None: - yield node - yield from self._preorder_traversal(node.left) - yield from self._preorder_traversal(node.right) - - -class BinarySearchTreeTest(unittest.TestCase): - @staticmethod - def _get_binary_search_tree() -> BinarySearchTree: - r""" - 8 - / \ - 3 10 - / \ \ - 1 6 14 - / \ / - 4 7 13 - \ - 5 - """ - t = BinarySearchTree() - t.put(8) - t.put(3) - t.put(6) - t.put(1) - t.put(10) - t.put(14) - t.put(13) - t.put(4) - t.put(7) - t.put(5) - - return t - - def test_put(self) -> None: - t = BinarySearchTree() - assert t.is_empty() - - t.put(8) - r""" - 8 - """ - assert t.root is not None - assert t.root.parent is None - assert t.root.label == 8 - - t.put(10) - r""" - 8 - \ - 10 - """ - assert t.root.right is not None - assert t.root.right.parent == t.root - assert t.root.right.label == 10 - - t.put(3) - r""" - 8 - / \ - 3 10 - """ - assert t.root.left is not None - assert t.root.left.parent == t.root - assert t.root.left.label == 3 - - t.put(6) - r""" - 8 - / \ - 3 10 - \ - 6 - """ - assert t.root.left.right is not None - assert t.root.left.right.parent == t.root.left - assert t.root.left.right.label == 6 - - t.put(1) - r""" - 8 - / \ - 3 10 - / \ - 1 6 - """ - assert t.root.left.left is not None - assert t.root.left.left.parent == t.root.left - assert t.root.left.left.label == 1 - - with pytest.raises(ValueError): - t.put(1) - - def test_search(self) -> None: - t = self._get_binary_search_tree() - - node = t.search(6) - assert node.label == 6 - - node = t.search(13) - assert node.label == 13 - - with pytest.raises(ValueError): - t.search(2) - - def test_remove(self) -> None: - t = self._get_binary_search_tree() - - t.remove(13) - r""" - 8 - / \ - 3 10 - / \ \ - 1 6 14 - / \ - 4 7 - \ - 5 - """ - assert t.root is not None - assert t.root.right is not None - assert t.root.right.right is not None - assert t.root.right.right.right is None - assert t.root.right.right.left is None - - t.remove(7) - r""" - 8 - / \ - 3 10 - / \ \ - 1 6 14 - / - 4 - \ - 5 - """ - assert t.root.left is not None - assert t.root.left.right is not None - assert t.root.left.right.left is not None - assert t.root.left.right.right is None - assert t.root.left.right.left.label == 4 - - t.remove(6) - r""" - 8 - / \ - 3 10 - / \ \ - 1 4 14 - \ - 5 - """ - assert t.root.left.left is not None - assert t.root.left.right.right is not None - assert t.root.left.left.label == 1 - assert t.root.left.right.label == 4 - assert t.root.left.right.right.label == 5 - assert t.root.left.right.left is None - assert t.root.left.left.parent == t.root.left - assert t.root.left.right.parent == t.root.left - - t.remove(3) - r""" - 8 - / \ - 4 10 - / \ \ - 1 5 14 - """ - assert t.root is not None - assert t.root.left.label == 4 - assert t.root.left.right.label == 5 - assert t.root.left.left.label == 1 - assert t.root.left.parent == t.root - assert t.root.left.left.parent == t.root.left - assert t.root.left.right.parent == t.root.left - - t.remove(4) - r""" - 8 - / \ - 5 10 - / \ - 1 14 - """ - assert t.root.left is not None - assert t.root.left.left is not None - assert t.root.left.label == 5 - assert t.root.left.right is None - assert t.root.left.left.label == 1 - assert t.root.left.parent == t.root - assert t.root.left.left.parent == t.root.left - - def test_remove_2(self) -> None: - t = self._get_binary_search_tree() - - t.remove(3) - r""" - 8 - / \ - 4 10 - / \ \ - 1 6 14 - / \ / - 5 7 13 - """ - assert t.root is not None - assert t.root.left is not None - assert t.root.left.left is not None - assert t.root.left.right is not None - assert t.root.left.right.left is not None - assert t.root.left.right.right is not None - assert t.root.left.label == 4 - assert t.root.left.right.label == 6 - assert t.root.left.left.label == 1 - assert t.root.left.right.right.label == 7 - assert t.root.left.right.left.label == 5 - assert t.root.left.parent == t.root - assert t.root.left.right.parent == t.root.left - assert t.root.left.left.parent == t.root.left - assert t.root.left.right.left.parent == t.root.left.right - - def test_empty(self) -> None: - t = self._get_binary_search_tree() - t.empty() - assert t.root is None - - def test_is_empty(self) -> None: - t = self._get_binary_search_tree() - assert not t.is_empty() - - t.empty() - assert t.is_empty() - - def test_exists(self) -> None: - t = self._get_binary_search_tree() - - assert t.exists(6) - assert not t.exists(-1) - - def test_get_max_label(self) -> None: - t = self._get_binary_search_tree() - - assert t.get_max_label() == 14 - - t.empty() - with pytest.raises(ValueError): - t.get_max_label() - - def test_get_min_label(self) -> None: - t = self._get_binary_search_tree() - - assert t.get_min_label() == 1 - - t.empty() - with pytest.raises(ValueError): - t.get_min_label() - - def test_inorder_traversal(self) -> None: - t = self._get_binary_search_tree() - - inorder_traversal_nodes = [i.label for i in t.inorder_traversal()] - assert inorder_traversal_nodes == [1, 3, 4, 5, 6, 7, 8, 10, 13, 14] - - def test_preorder_traversal(self) -> None: - t = self._get_binary_search_tree() - - preorder_traversal_nodes = [i.label for i in t.preorder_traversal()] - assert preorder_traversal_nodes == [8, 3, 1, 6, 4, 5, 7, 10, 14, 13] - - -def binary_search_tree_example() -> None: - r""" - Example - 8 - / \ - 3 10 - / \ \ - 1 6 14 - / \ / - 4 7 13 - \ - 5 - - Example After Deletion - 4 - / \ - 1 7 - \ - 5 - - """ - - t = BinarySearchTree() - t.put(8) - t.put(3) - t.put(6) - t.put(1) - t.put(10) - t.put(14) - t.put(13) - t.put(4) - t.put(7) - t.put(5) - - print( - """ - 8 - / \\ - 3 10 - / \\ \\ - 1 6 14 - / \\ / - 4 7 13 - \\ - 5 - """ - ) - - print("Label 6 exists:", t.exists(6)) - print("Label 13 exists:", t.exists(13)) - print("Label -1 exists:", t.exists(-1)) - print("Label 12 exists:", t.exists(12)) - - # Prints all the elements of the list in inorder traversal - inorder_traversal_nodes = [i.label for i in t.inorder_traversal()] - print("Inorder traversal:", inorder_traversal_nodes) - - # Prints all the elements of the list in preorder traversal - preorder_traversal_nodes = [i.label for i in t.preorder_traversal()] - print("Preorder traversal:", preorder_traversal_nodes) - - print("Max. label:", t.get_max_label()) - print("Min. label:", t.get_min_label()) - - # Delete elements - print("\nDeleting elements 13, 10, 8, 3, 6, 14") - print( - """ - 4 - / \\ - 1 7 - \\ - 5 - """ - ) - t.remove(13) - t.remove(10) - t.remove(8) - t.remove(3) - t.remove(6) - t.remove(14) - - # Prints all the elements of the list in inorder traversal after delete - inorder_traversal_nodes = [i.label for i in t.inorder_traversal()] - print("Inorder traversal after delete:", inorder_traversal_nodes) - - # Prints all the elements of the list in preorder traversal after delete - preorder_traversal_nodes = [i.label for i in t.preorder_traversal()] - print("Preorder traversal after delete:", preorder_traversal_nodes) - - print("Max. label:", t.get_max_label()) - print("Min. label:", t.get_min_label()) - - -if __name__ == "__main__": - binary_search_tree_example() diff --git a/data_structures/stacks/balanced_parentheses.py b/data_structures/stacks/balanced_parentheses.py index 928815bb2111..aa3890ac4fb2 100644 --- a/data_structures/stacks/balanced_parentheses.py +++ b/data_structures/stacks/balanced_parentheses.py @@ -1,4 +1,20 @@ -from .stack import Stack +try: + from .stack import Stack +except Exception: # pragma: no cover - fallback for direct script execution / doctest + # When this module is executed directly (for example via `python -m doctest`), + # package-relative imports like `from .stack import Stack` may fail because + # there's no package context. Load the sibling `stack.py` file directly as a + # module so the functions here still work when run from the filesystem. + import importlib.util + import sys + from pathlib import Path + + stack_path = Path(__file__).with_name("stack.py") + spec = importlib.util.spec_from_file_location("data_structures.stacks.stack", str(stack_path)) + _stack = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = _stack + spec.loader.exec_module(_stack) # type: ignore[attr-defined] + Stack = _stack.Stack def balanced_parentheses(parentheses: str) -> bool: diff --git a/scripts/run_doctests.py b/scripts/run_doctests.py new file mode 100644 index 000000000000..48978c035b1c --- /dev/null +++ b/scripts/run_doctests.py @@ -0,0 +1,98 @@ +"""Run doctests for all .py files under data_structures and summarize results. + +Usage: python scripts/run_doctests.py + +The script will try to import each file as a package module (e.g. data_structures.x.y) +so relative imports work. If import fails, it will fall back to running doctest.testfile +on the file path. + +Exits with non-zero status when any failures or uncaught exceptions occur. +""" + +from __future__ import annotations + +import doctest +import importlib +import importlib.util +import os +import sys +import traceback +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] # Python/ (project root where data_structures lives) +DATA_DIR = ROOT / "data_structures" + + +def main() -> int: + sys.path.insert(0, str(ROOT)) + + results = [] # tuples: (relpath, status, details) + total_tests = 0 + total_failures = 0 + + for py in sorted(DATA_DIR.rglob("*.py")): + rel = py.relative_to(ROOT) + module_name = ".".join(rel.with_suffix("").parts) + # skip compiled or vendored files if any + print(f"\n=== Running doctests for {rel} (module {module_name}) ===") + try: + # Try to import as module first (so package-relative imports work) + module = importlib.import_module(module_name) + # run doctest on module + res = doctest.testmod(module, verbose=False) + failures, tests = res + total_tests += tests + total_failures += failures + status = "ok" if failures == 0 else "fail" + details = f"imported module; tests={tests}, failures={failures}" + print(details) + if failures: + # rerun verbosely to show failing examples + print("--- Verbose output for failures ---") + doctest.testmod(module, verbose=True) + results.append((str(rel), status, details)) + except Exception as e: + # Import failed; try doctest.testfile on the path + print(f"Import failed: {e.__class__.__name__}: {e}") + traceback.print_exc() + try: + failures, tests = doctest.testfile(str(py), module_relative=False) + total_tests += tests + total_failures += failures + status = "ok" if failures == 0 else "fail" + details = f"testfile fallback; tests={tests}, failures={failures}" + print(details) + if failures: + print("--- Verbose output for failures ---") + doctest.testfile(str(py), module_relative=False, verbose=True) + results.append((str(rel), status, details)) + except Exception as ex2: + print(f"Fallback testfile raised {ex2.__class__.__name__}: {ex2}") + traceback.print_exc() + results.append((str(rel), "error", f"import error: {e}; fallback error: {ex2}")) + + # Summary + print("\n=== Summary ===") + print(f"Files checked: {len(results)}") + print(f"Total doctest examples run: {total_tests}") + print(f"Total failures: {total_failures}") + print("") + for f, status, details in results: + if status != "ok": + print(f"{f}: {status} - {details}") + + if total_failures > 0: + print("\nSome doctests failed.") + return 1 + + # If any file had 'error' status, exit non-zero + if any(r[1] == "error" for r in results): + print("\nSome files raised errors while running doctests.") + return 2 + + print("\nAll doctests passed.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From a84ebb32aa4ae6b08f5f6f14f9e4af7c9f286ca8 Mon Sep 17 00:00:00 2001 From: athikha-faiz_infosys Date: Fri, 17 Oct 2025 18:33:45 +0530 Subject: [PATCH 6/7] Add optional local_search to SimulatedAnnealing; implement simple_local_search and doctest demonstrating benefit --- simulated_annealing/simulated_annealing.py | 78 +++++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/simulated_annealing/simulated_annealing.py b/simulated_annealing/simulated_annealing.py index 1ad354dde574..b797c09402d4 100644 --- a/simulated_annealing/simulated_annealing.py +++ b/simulated_annealing/simulated_annealing.py @@ -21,6 +21,8 @@ def __init__( min_temperature: float = 1e-3, iterations_per_temp: int = 100, neighbor_scale: float = 0.1, + local_search: Optional[Callable[[Sequence[float], Callable[[Sequence[float]], float], Callable[[Sequence[float]], Sequence[float]], int], Tuple[Sequence[float], float]]] = None, + local_search_iters: int = 10, seed: Optional[int] = None, ): self.func = func @@ -33,6 +35,9 @@ def __init__( self.min_temperature = float(min_temperature) self.iterations_per_temp = int(iterations_per_temp) self.neighbor_scale = float(neighbor_scale) + # local_search: callable(solution, func, neighbor_fn, iters) -> (improved_solution, improved_cost) + self.local_search = local_search + self.local_search_iters = int(local_search_iters) if seed is not None: random.seed(seed) @@ -91,7 +96,17 @@ def optimize(self, max_steps: Optional[int] = None, stop_event: Optional[object] return best, best_cost, history candidate = self._neighbor(current) - candidate_cost = float(self.func(candidate)) + # Optionally refine candidate with local search before evaluating/accepting + if self.local_search is not None: + try: + improved, improved_cost = self.local_search(candidate, self.func, self._neighbor, self.local_search_iters) + candidate = list(improved) + candidate_cost = float(improved_cost) + except Exception: + # Fall back to plain candidate evaluation if local search fails + candidate_cost = float(self.func(candidate)) + else: + candidate_cost = float(self.func(candidate)) delta = candidate_cost - current_cost if self._accept(delta, temp): current = candidate @@ -131,5 +146,66 @@ def _test_quadratic(): print("best:", best, "cost:", cost) +def simple_local_search(solution: Sequence[float], func: Callable[[Sequence[float]], float], neighbor_fn: Callable[[Sequence[float]], Sequence[float]], iterations: int = 10) -> Tuple[Sequence[float], float]: + """A tiny hill-climbing local search that repeatedly accepts improving neighbors. + + Parameters + - solution: starting solution sequence + - func: objective function (lower is better) + - neighbor_fn: function that given a solution returns a new neighbor solution + - iterations: number of neighbor attempts + + Returns a tuple (best_solution, best_cost). + + >>> func = lambda x: (x[0] - 5) ** 2 + >>> start = [0.0] + >>> def neighbor(x): + ... return [x[0] + 0.5] + >>> best, cost = simple_local_search(start, func, neighbor, iterations=5) + >>> best[0] > start[0] + True + >>> cost == func(best) + True + """ + best = list(solution) + best_cost = float(func(best)) + for _ in range(int(iterations)): + cand = neighbor_fn(best) + cand_cost = float(func(cand)) + if cand_cost < best_cost: + best = list(cand) + best_cost = cand_cost + return best, best_cost + + +def _doctest_local_search_benefit(): + """Demonstrate that providing a local_search can improve or match the solution found by SimulatedAnnealing. + + The test uses a deterministic seed so the result is reproducible in doctest. + + >>> func = lambda x: (x[0] - 5) ** 2 + >>> sa1 = SimulatedAnnealing(func, [0.0], bounds=[(-10, 10)], temperature=10, iterations_per_temp=20, seed=1) + >>> best1, cost1, _ = sa1.optimize(max_steps=200) + >>> # define a deterministic, greedy local search that moves toward the known minimum (5.0) + >>> def my_local_search(sol, f, neighbour, iters): + ... s = list(sol) + ... bestc = float(f(s)) + ... for _ in range(int(iters)): + ... # move halfway toward 5.0 (gradient-free, deterministic) + ... s[0] = s[0] + 0.5 * (5.0 - s[0]) + ... c = float(f(s)) + ... if c < bestc: + ... bestc = c + ... else: + ... break + ... return s, bestc + >>> sa2 = SimulatedAnnealing(func, [0.0], bounds=[(-10, 10)], temperature=10, iterations_per_temp=20, seed=1, local_search=my_local_search, local_search_iters=5) + >>> best2, cost2, _ = sa2.optimize(max_steps=200) + >>> cost2 <= cost1 + True + """ + pass + + if __name__ == "__main__": _test_quadratic() From f3cb719b69ea151e760613426e6ba62a875e25b0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 Oct 2025 13:24:49 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- data_structures/simple/run_ds_tests.py | 36 +++++++------- .../stacks/balanced_parentheses.py | 4 +- scripts/run_doctests.py | 8 +++- simulated_annealing/__init__.py | 1 + simulated_annealing/example.py | 18 +++++-- simulated_annealing/gui.py | 37 ++++++++++++--- simulated_annealing/simulated_annealing.py | 47 ++++++++++++++++--- tests/test_simulated_annealing.py | 21 +++++++-- tests/test_tsp.py | 2 +- 9 files changed, 132 insertions(+), 42 deletions(-) diff --git a/data_structures/simple/run_ds_tests.py b/data_structures/simple/run_ds_tests.py index 67aa19b36b47..c321e93dec11 100644 --- a/data_structures/simple/run_ds_tests.py +++ b/data_structures/simple/run_ds_tests.py @@ -26,7 +26,7 @@ def run() -> None: assert ll.to_list() == [0, 2] print("LinkedList: OK") except Exception as e: - failures.append(('LinkedList', e)) + failures.append(("LinkedList", e)) try: s = ListStack() @@ -36,13 +36,13 @@ def run() -> None: assert s.pop() == 2 ls = LinkedStack() - ls.push('a') - ls.push('b') - assert ls.peek() == 'b' - assert ls.pop() == 'b' + ls.push("a") + ls.push("b") + assert ls.peek() == "b" + assert ls.pop() == "b" print("Stacks: OK") except Exception as e: - failures.append(('Stacks', e)) + failures.append(("Stacks", e)) try: q = Queue() @@ -52,7 +52,7 @@ def run() -> None: assert q.dequeue() == 1 print("Queue: OK") except Exception as e: - failures.append(('Queue', e)) + failures.append(("Queue", e)) try: bst = BinarySearchTree() @@ -62,7 +62,7 @@ def run() -> None: assert bst.inorder() == [1, 2, 3, 4] print("BST: OK") except Exception as e: - failures.append(('BST', e)) + failures.append(("BST", e)) try: g = Graph() @@ -74,16 +74,16 @@ def run() -> None: assert 1 in order_dfs print("Graph: OK") except Exception as e: - failures.append(('Graph', e)) + failures.append(("Graph", e)) try: t = Trie() - t.insert('hello') - assert t.search('hello') - assert t.starts_with('hel') + t.insert("hello") + assert t.search("hello") + assert t.starts_with("hel") print("Trie: OK") except Exception as e: - failures.append(('Trie', e)) + failures.append(("Trie", e)) try: uf = UnionFind() @@ -93,7 +93,7 @@ def run() -> None: assert uf.find(1) == uf.find(2) print("UnionFind: OK") except Exception as e: - failures.append(('UnionFind', e)) + failures.append(("UnionFind", e)) try: h = MinHeap() @@ -104,16 +104,16 @@ def run() -> None: assert h.pop() == 1 print("MinHeap: OK") except Exception as e: - failures.append(('MinHeap', e)) + failures.append(("MinHeap", e)) if failures: - print('\nFAILURES:') + print("\nFAILURES:") for name, exc in failures: print(f"- {name}: {exc}") sys.exit(2) else: - print('\nAll data_structures.simple smoke tests passed.') + print("\nAll data_structures.simple smoke tests passed.") -if __name__ == '__main__': +if __name__ == "__main__": run() diff --git a/data_structures/stacks/balanced_parentheses.py b/data_structures/stacks/balanced_parentheses.py index aa3890ac4fb2..2658f3f94f1d 100644 --- a/data_structures/stacks/balanced_parentheses.py +++ b/data_structures/stacks/balanced_parentheses.py @@ -10,7 +10,9 @@ from pathlib import Path stack_path = Path(__file__).with_name("stack.py") - spec = importlib.util.spec_from_file_location("data_structures.stacks.stack", str(stack_path)) + spec = importlib.util.spec_from_file_location( + "data_structures.stacks.stack", str(stack_path) + ) _stack = importlib.util.module_from_spec(spec) sys.modules[spec.name] = _stack spec.loader.exec_module(_stack) # type: ignore[attr-defined] diff --git a/scripts/run_doctests.py b/scripts/run_doctests.py index 48978c035b1c..bd5303b216e3 100644 --- a/scripts/run_doctests.py +++ b/scripts/run_doctests.py @@ -19,7 +19,9 @@ import traceback from pathlib import Path -ROOT = Path(__file__).resolve().parents[1] # Python/ (project root where data_structures lives) +ROOT = ( + Path(__file__).resolve().parents[1] +) # Python/ (project root where data_structures lives) DATA_DIR = ROOT / "data_structures" @@ -69,7 +71,9 @@ def main() -> int: except Exception as ex2: print(f"Fallback testfile raised {ex2.__class__.__name__}: {ex2}") traceback.print_exc() - results.append((str(rel), "error", f"import error: {e}; fallback error: {ex2}")) + results.append( + (str(rel), "error", f"import error: {e}; fallback error: {ex2}") + ) # Summary print("\n=== Summary ===") diff --git a/simulated_annealing/__init__.py b/simulated_annealing/__init__.py index bfaf1d331c90..19094b182a70 100644 --- a/simulated_annealing/__init__.py +++ b/simulated_annealing/__init__.py @@ -4,6 +4,7 @@ - SimulatedAnnealing: core optimizer class - example_functions: a small collection of test functions """ + from .simulated_annealing import SimulatedAnnealing from .example import example_functions diff --git a/simulated_annealing/example.py b/simulated_annealing/example.py index 4adb6de93943..034e3bce0ec1 100644 --- a/simulated_annealing/example.py +++ b/simulated_annealing/example.py @@ -8,7 +8,9 @@ def sphere(x: Sequence[float]) -> float: def rastrigin(x: Sequence[float]) -> float: # Rastrigin function (common test function) A = 10 - return A * len(x) + sum((v * v - A * __import__("math").cos(2 * __import__("math").pi * v)) for v in x) + return A * len(x) + sum( + (v * v - A * __import__("math").cos(2 * __import__("math").pi * v)) for v in x + ) example_functions: Dict[str, Callable[[Sequence[float]], float]] = { @@ -20,10 +22,18 @@ def rastrigin(x: Sequence[float]) -> float: def cli_example(): # CLI demo minimizing 2D sphere from .simulated_annealing import SimulatedAnnealing + func = sphere initial = [5.0, -3.0] bounds = [(-10, 10), (-10, 10)] - sa = SimulatedAnnealing(func, initial, bounds=bounds, temperature=50, cooling_rate=0.95, iterations_per_temp=200) + sa = SimulatedAnnealing( + func, + initial, + bounds=bounds, + temperature=50, + cooling_rate=0.95, + iterations_per_temp=200, + ) best, cost, history = sa.optimize() print("Best:", best) print("Cost:", cost) @@ -41,7 +51,9 @@ def tsp_example(): initial = [float(i) for i in init_tour] cost_fn = make_tsp_cost(coords) - sa = SimulatedAnnealing(cost_fn, initial, temperature=100.0, cooling_rate=0.995, iterations_per_temp=500) + sa = SimulatedAnnealing( + cost_fn, initial, temperature=100.0, cooling_rate=0.995, iterations_per_temp=500 + ) best, cost, history = sa.optimize() best_tour = vector_to_tour(best) print("Best tour:", best_tour) diff --git a/simulated_annealing/gui.py b/simulated_annealing/gui.py index a0ce4f1454f1..fd6fc270d3b3 100644 --- a/simulated_annealing/gui.py +++ b/simulated_annealing/gui.py @@ -4,6 +4,7 @@ from typing import Optional import matplotlib + matplotlib.use("TkAgg") from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg import matplotlib.pyplot as plt @@ -24,7 +25,12 @@ def __init__(self): ttk.Label(ctrl, text="Function:").pack(anchor=tk.W) self.func_var = tk.StringVar(value="sphere") - func_menu = ttk.Combobox(ctrl, textvariable=self.func_var, values=list(example_functions.keys()), state="readonly") + func_menu = ttk.Combobox( + ctrl, + textvariable=self.func_var, + values=list(example_functions.keys()), + state="readonly", + ) func_menu.pack(fill=tk.X) ttk.Label(ctrl, text="Initial (comma-separated)").pack(anchor=tk.W, pady=(8, 0)) @@ -32,7 +38,9 @@ def __init__(self): self.init_entry.insert(0, "5, -3") self.init_entry.pack(fill=tk.X) - ttk.Label(ctrl, text="Bounds (lo:hi comma-separated for each)").pack(anchor=tk.W, pady=(8, 0)) + ttk.Label(ctrl, text="Bounds (lo:hi comma-separated for each)").pack( + anchor=tk.W, pady=(8, 0) + ) self.bounds_entry = ttk.Entry(ctrl) self.bounds_entry.insert(0, "-10:10, -10:10") self.bounds_entry.pack(fill=tk.X) @@ -56,7 +64,9 @@ def __init__(self): self.run_btn.pack(fill=tk.X, pady=(12, 0)) self.stop_flag = threading.Event() - self.stop_btn = ttk.Button(ctrl, text="Stop", command=self._on_stop, state=tk.DISABLED) + self.stop_btn = ttk.Button( + ctrl, text="Stop", command=self._on_stop, state=tk.DISABLED + ) self.stop_btn.pack(fill=tk.X, pady=(6, 0)) # Right: plot @@ -65,7 +75,7 @@ def __init__(self): self.canvas = FigureCanvasTkAgg(fig, master=self) self.canvas.get_tk_widget().pack(side=tk.RIGHT, fill=tk.BOTH, expand=1) - self._plot_line, = self.ax.plot([], [], label="best_cost") + (self._plot_line,) = self.ax.plot([], [], label="best_cost") self.ax.set_xlabel("Iterations") self.ax.set_ylabel("Best cost") self.ax.grid(True) @@ -120,13 +130,22 @@ def _on_run(self): self.stop_flag.clear() def worker(): - sa = SimulatedAnnealing(func, initial, bounds=bounds, temperature=temp, cooling_rate=cooling, iterations_per_temp=iterations) + sa = SimulatedAnnealing( + func, + initial, + bounds=bounds, + temperature=temp, + cooling_rate=cooling, + iterations_per_temp=iterations, + ) def progress_cb(step, best_cost, current_cost): # schedule a plot update on the main thread self.after(0, lambda: self._update_plot_partial(step, best_cost)) - best, cost, history = sa.optimize(stop_event=self.stop_flag, progress_callback=progress_cb) + best, cost, history = sa.optimize( + stop_event=self.stop_flag, progress_callback=progress_cb + ) # update final plot on main thread self.after(0, lambda: self._on_complete(best, cost, history)) @@ -158,7 +177,11 @@ def _on_complete(self, best, cost, history): def _update_plot_partial(self, step: int, best_cost: float): # Append a new point to plot (x=step, y=best_cost) # We'll redraw full plot for simplicity - line_x = list(range(len(self.ax.lines[0].get_xdata()) + 1)) if self.ax.lines else [step] + line_x = ( + list(range(len(self.ax.lines[0].get_xdata()) + 1)) + if self.ax.lines + else [step] + ) if self.ax.lines: ydata = list(self.ax.lines[0].get_ydata()) ydata.append(best_cost) diff --git a/simulated_annealing/simulated_annealing.py b/simulated_annealing/simulated_annealing.py index b797c09402d4..62f634df7d90 100644 --- a/simulated_annealing/simulated_annealing.py +++ b/simulated_annealing/simulated_annealing.py @@ -21,7 +21,17 @@ def __init__( min_temperature: float = 1e-3, iterations_per_temp: int = 100, neighbor_scale: float = 0.1, - local_search: Optional[Callable[[Sequence[float], Callable[[Sequence[float]], float], Callable[[Sequence[float]], Sequence[float]], int], Tuple[Sequence[float], float]]] = None, + local_search: Optional[ + Callable[ + [ + Sequence[float], + Callable[[Sequence[float]], float], + Callable[[Sequence[float]], Sequence[float]], + int, + ], + Tuple[Sequence[float], float], + ] + ] = None, local_search_iters: int = 10, seed: Optional[int] = None, ): @@ -70,7 +80,12 @@ def _accept(self, delta: float, temp: float) -> bool: prob = 0.0 return random.random() < prob - def optimize(self, max_steps: Optional[int] = None, stop_event: Optional[object] = None, progress_callback: Optional[Callable[[int, float, float], None]] = None) -> Tuple[List[float], float, dict]: + def optimize( + self, + max_steps: Optional[int] = None, + stop_event: Optional[object] = None, + progress_callback: Optional[Callable[[int, float, float], None]] = None, + ) -> Tuple[List[float], float, dict]: """Run optimization and return best solution, cost, and history. New optional args: @@ -91,7 +106,10 @@ def optimize(self, max_steps: Optional[int] = None, stop_event: Optional[object] while temp > self.min_temperature: for _ in range(self.iterations_per_temp): # Check stop event - if stop_event is not None and getattr(stop_event, "is_set", lambda: False)(): + if ( + stop_event is not None + and getattr(stop_event, "is_set", lambda: False)() + ): self.current = current return best, best_cost, history @@ -99,7 +117,12 @@ def optimize(self, max_steps: Optional[int] = None, stop_event: Optional[object] # Optionally refine candidate with local search before evaluating/accepting if self.local_search is not None: try: - improved, improved_cost = self.local_search(candidate, self.func, self._neighbor, self.local_search_iters) + improved, improved_cost = self.local_search( + candidate, + self.func, + self._neighbor, + self.local_search_iters, + ) candidate = list(improved) candidate_cost = float(improved_cost) except Exception: @@ -120,7 +143,10 @@ def optimize(self, max_steps: Optional[int] = None, stop_event: Optional[object] history["current_costs"].append(current_cost) steps += 1 - if progress_callback is not None and steps % max(1, self.iterations_per_temp // 10) == 0: + if ( + progress_callback is not None + and steps % max(1, self.iterations_per_temp // 10) == 0 + ): try: progress_callback(steps, best_cost, current_cost) except Exception: @@ -141,12 +167,19 @@ def optimize(self, max_steps: Optional[int] = None, stop_event: Optional[object] def _test_quadratic(): # Simple test: minimize f(x) = (x-3)^2 func = lambda x: (x[0] - 3) ** 2 - sa = SimulatedAnnealing(func, [0.0], bounds=[(-10, 10)], temperature=10, iterations_per_temp=50) + sa = SimulatedAnnealing( + func, [0.0], bounds=[(-10, 10)], temperature=10, iterations_per_temp=50 + ) best, cost, hist = sa.optimize() print("best:", best, "cost:", cost) -def simple_local_search(solution: Sequence[float], func: Callable[[Sequence[float]], float], neighbor_fn: Callable[[Sequence[float]], Sequence[float]], iterations: int = 10) -> Tuple[Sequence[float], float]: +def simple_local_search( + solution: Sequence[float], + func: Callable[[Sequence[float]], float], + neighbor_fn: Callable[[Sequence[float]], Sequence[float]], + iterations: int = 10, +) -> Tuple[Sequence[float], float]: """A tiny hill-climbing local search that repeatedly accepts improving neighbors. Parameters diff --git a/tests/test_simulated_annealing.py b/tests/test_simulated_annealing.py index 76a912df1a07..1fbb253823f4 100644 --- a/tests/test_simulated_annealing.py +++ b/tests/test_simulated_annealing.py @@ -8,15 +8,30 @@ def sphere(x): class TestSimulatedAnnealing(unittest.TestCase): def test_minimize_sphere_1d(self): - sa = SimulatedAnnealing(sphere, [5.0], bounds=[(-10, 10)], temperature=10, cooling_rate=0.9, iterations_per_temp=50) + sa = SimulatedAnnealing( + sphere, + [5.0], + bounds=[(-10, 10)], + temperature=10, + cooling_rate=0.9, + iterations_per_temp=50, + ) best, cost, hist = sa.optimize() # Best should be near 0 with tiny cost self.assertLess(cost, 1e-2) def test_stop_event(self): import threading + stop = threading.Event() - sa = SimulatedAnnealing(sphere, [5.0], bounds=[(-10, 10)], temperature=10, cooling_rate=0.9, iterations_per_temp=1000) + sa = SimulatedAnnealing( + sphere, + [5.0], + bounds=[(-10, 10)], + temperature=10, + cooling_rate=0.9, + iterations_per_temp=1000, + ) # request stop immediately stop.set() best, cost, hist = sa.optimize(stop_event=stop) @@ -24,5 +39,5 @@ def test_stop_event(self): self.assertIsNotNone(best) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_tsp.py b/tests/test_tsp.py index a9ac0b9f0dcb..2badc6f035d5 100644 --- a/tests/test_tsp.py +++ b/tests/test_tsp.py @@ -16,5 +16,5 @@ def test_total_distance_cycle(self): self.assertGreater(d, 0) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()