def upgrade() -> None:
    """Create the ``browser_event`` table.

    The table has an integer primary key ``id``, a nullable JSON ``message``
    column and two nullable ``ForceFloat`` columns, ``timestamp`` and
    ``recording_timestamp``; the latter carries a foreign key to
    ``recording.timestamp``.
    """
    # NOTE(review): ``openadapt.models`` is reached through a bare
    # ``import openadapt`` -- confirm the package exposes the submodule.
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "browser_event",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column(
            "recording_timestamp",
            openadapt.models.ForceFloat(precision=10, scale=2, asdecimal=False),
            nullable=True,
        ),
        sa.Column("message", sa.JSON(), nullable=True),
        sa.Column(
            "timestamp",
            openadapt.models.ForceFloat(precision=10, scale=2, asdecimal=False),
            nullable=True,
        ),
        sa.ForeignKeyConstraint(
            ["recording_timestamp"],
            ["recording.timestamp"],
            name=op.f("fk_browser_event_recording_timestamp_recording"),
        ),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_browser_event")),
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    """Drop the ``browser_event`` table, reversing ``upgrade``."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table("browser_event")
    # ### end Alembic commands ###
/**
 * @file background.js
 * @description Relays messages from the content script to the native
 * messaging host and logs the host's replies.
 */

const hostName = "openadapt";
var port = null; // Native Messaging port; null while disconnected
var lastMsg = null; // last message forwarded to the host (used for de-duplication)

/*
 * Handle received messages from browser.py (the native messaging host)
 */
function onReceived(response) {
  console.log(response);
}

/*
 * Mark the port as gone. Reading chrome.runtime.lastError tells Chrome the
 * error was handled; it may be undefined, so it must not be dereferenced
 * blindly.
 */
function onDisconnected() {
  const lastError = chrome.runtime.lastError;
  if (lastError) {
    console.warn("Failed to connect: " + lastError.message);
  }
  port = null;
}

/*
 * Open the native messaging port and register its listeners.
 */
function connect() {
  port = chrome.runtime.connectNative(hostName);
  port.onMessage.addListener(onReceived);
  port.onDisconnect.addListener(onDisconnected);
}

/*
 * Message listener for content script.
 *
 * A message is skipped when it repeats the previous one (same tagName and
 * action) within `timestampThreshold` milliseconds. Otherwise it is posted
 * to the native host; if posting fails the port is reopened and the post is
 * retried once, so the message that revealed the dead port is not lost.
 */
function messageListener(message, sender, sendResponse) {
  const timestampThreshold = 30; // arbitrary threshold in milliseconds

  if (
    lastMsg !== null &&
    Math.abs(message.timestamp - lastMsg.timestamp) < timestampThreshold &&
    message.tagName === lastMsg.tagName &&
    message.action === lastMsg.action
  ) {
    return;
  }
  console.log({ message, sender, sendResponse });

  try {
    if (port === null) {
      connect();
    }
    port.postMessage(message); // send to browser.py (native messaging host)
  } catch (e) {
    try {
      connect();
      port.postMessage(message);
    } catch (retryError) {
      console.warn("Dropping message, native host unavailable:", retryError);
      return;
    }
  }
  lastMsg = message;
}

connect();
chrome.runtime.onMessage.addListener(messageListener);
SOCKETS = True
DBG_DATABASE = False


def get_message() -> dict:
    """Read one native-messaging message from stdin and decode it.

    A message is a 4-byte native-order unsigned length followed by that many
    bytes of UTF-8 encoded JSON.

    Returns:
        The decoded JSON payload.

    Raises:
        SystemExit: With status 0 when stdin is at EOF or the stream ends in
            the middle of a message (the browser closed the pipe).
    """
    stdin = sys.stdin.buffer
    raw_length = stdin.read(4)
    # A short read (0-3 bytes) means the pipe was closed; unpacking it would
    # raise struct.error instead of exiting cleanly.
    if len(raw_length) < 4:
        sys.exit(0)
    message_length = struct.unpack("@I", raw_length)[0]
    raw_message = stdin.read(message_length)
    if len(raw_message) < message_length:
        sys.exit(0)
    return json.loads(raw_message.decode("utf-8"))


def encode_message(message_content: object) -> dict[str, bytes]:
    """Encode a message for transmission, given its content.

    Args:
        message_content: Any JSON-serialisable value.

    Returns:
        A dictionary with the packed 4-byte ``"length"`` prefix and the UTF-8
        encoded JSON ``"content"``.
    """
    # https://docs.python.org/3/library/json.html#basic-usage
    # Compact separators eliminate whitespace; the browser rejects messages
    # that exceed 1 MB, so the smallest representation is wanted.
    encoded_content = json.dumps(message_content, separators=(",", ":")).encode("utf-8")
    encoded_length = struct.pack("@I", len(encoded_content))
    return {"length": encoded_length, "content": encoded_content}


def send_message(encoded_message: dict) -> None:
    """Write an encoded message (length prefix, then content) to stdout.

    Args:
        encoded_message: A dictionary as returned by ``encode_message``.
    """
    stdout = sys.stdout.buffer
    stdout.write(encoded_message["length"])
    stdout.write(encoded_message["content"])
    stdout.flush()


def send_message_to_client(conn: "sockets.Connection", message: dict) -> None:
    """Forward a message to the socket client, reporting failures on stderr.

    Args:
        conn: The connection to the client.
        message: The message to be sent.
    """
    try:
        conn.send(message)
    except Exception as exc:
        # stdout carries the length-prefixed protocol read by the browser, so
        # diagnostics must go to stderr or they corrupt the message stream.
        print(f"Error sending message to client: {exc}", file=sys.stderr)


def main() -> None:
    """Relay messages from the browser extension to the socket client.

    Each message read from stdin is optionally logged to ``messages.db``
    (``DBG_DATABASE``), acknowledged to the browser when it is non-empty, and
    forwarded over the socket connection (``SOCKETS``). The loop ends when
    ``get_message`` exits the process at end of input.
    """
    db_conn = None
    cursor = None
    if DBG_DATABASE:
        db_conn = sqlite3.connect("messages.db")
        cursor = db_conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS messages (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                message TEXT NOT NULL
            )
        """)
    conn = None
    try:
        if SOCKETS:
            conn = sockets.create_server_connection(config.SOCKET_PORT)
        while True:
            if SOCKETS and conn.closed:
                conn = sockets.create_server_connection(config.SOCKET_PORT)
            message = get_message()

            # Log the message to the database
            if DBG_DATABASE:
                cursor.execute(
                    "INSERT INTO messages (message) VALUES (?)", (json.dumps(message),)
                )
                db_conn.commit()
            response = {"message": "Data received and logged successfully!"}
            if message:
                send_message(encode_message(response))
            # NOTE(review): nesting was lost in the collapsed source; forwarding
            # is assumed to be independent of the acknowledgement -- confirm.
            if SOCKETS:
                send_message_to_client(conn, message)
    finally:
        # Runs on SystemExit from get_message as well, so handles are released.
        if conn is not None:
            conn.close()
        if db_conn is not None:
            db_conn.close()
/**
 * @file content.js
 * @description Injected into the web page; records click and input events on
 * the page's elements and sends them, with an initial snapshot of the
 * document, to the background script.
 */

let logged = false;
let ignoreAttributes = new Set();
// Tracked elements keyed by element id: { element, x, y, value }.
const elements = {};
// Events already reported. Every element carries its own listeners, so one
// bubbling event reaches a handler once per ancestor; this keeps one report.
const handledEvents = new WeakSet();

/*
 * Function to send a message to the background script
 */
function sendMessageToBackgroundScript(message) {
  chrome.runtime.sendMessage(message);
}

/*
 * Function to capture initial document state and
 * send it to the background script
 */
function captureDocumentState() {
  const documentBody = document.body.outerHTML;
  const documentHead = document.head.outerHTML;
  const page_url = window.location.href;

  sendMessageToBackgroundScript({
    action: "captureDocumentState",
    documentBody: documentBody,
    documentHead: documentHead,
    elements: elements,
    url: page_url,
    timestamp: Date.now(),
  });
}

/*
 * Return true the first time an event object is seen, false afterwards.
 */
function isFirstDelivery(event) {
  if (handledEvents.has(event)) {
    return false;
  }
  handledEvents.add(event);
  return true;
}

/*
 * Copy an element's attributes into a plain { name: value } object.
 */
function collectAttributes(element) {
  const attributes = {};
  for (const attr of element.attributes) {
    attributes[attr.name] = attr.value;
  }
  return attributes;
}

/*
 * Report a click on event.target, with the position and value recorded when
 * the element was registered (undefined / "" for untracked elements).
 */
function handleElementClick(event) {
  if (!isFirstDelivery(event)) {
    return;
  }
  const element = event.target;
  const { x, y } = elements[element.id] || {};
  const value = elements[element.id]?.value || "";

  sendMessageToBackgroundScript({
    action: "elementClicked",
    tagName: element.tagName,
    attributes: collectAttributes(element),
    x: x,
    y: y,
    value: value,
    url: window.location.href,
    timestamp: Date.now(),
  });
}

/*
 * Return a wrapper that calls func only after `delay` ms without a new call.
 */
function debounce(func, delay) {
  let timerId;
  return function (...args) {
    if (timerId) {
      clearTimeout(timerId);
    }
    timerId = setTimeout(() => {
      func.apply(this, args);
      timerId = null;
    }, delay);
  };
}

/*
 * Report the current value of event.target after typing has paused.
 * Elements that were never registered (e.g. added after page load) have no
 * entry in `elements`; they are reported without coordinates instead of
 * throwing.
 */
function handleDebouncedInput(event) {
  if (!isFirstDelivery(event)) {
    return;
  }
  const element = event.target;
  const { x, y } = elements[element.id] || {};

  sendMessageToBackgroundScript({
    action: "elementInput",
    tagName: element.tagName,
    attributes: collectAttributes(element),
    x: x,
    y: y,
    value: element.value,
    url: window.location.href,
    timestamp: Date.now(),
  });
}

const debouncedInputHandler = debounce(handleDebouncedInput, 500);

function handleElementInput(event) {
  debouncedInputHandler(event);
}

/*
 * Register an element: record its page position and value, give it an id if
 * it has none, and attach the click and (debounced) input listeners.
 */
function addElement(element) {
  const rect = element.getBoundingClientRect();
  const x = rect.left + window.scrollX;
  const y = rect.top + window.scrollY;
  const value = element.value;
  if (!element.id) {
    element.id = element.tagName + "_" + x + "_" + y;
  }
  elements[element.id] = { element, x, y, value };
  element.addEventListener("click", handleElementClick);
  element.addEventListener("input", debounce(handleDebouncedInput, 500));
}

/*
 * Register every element currently in the document.
 */
function addEventListeners() {
  // Named differently from the module-level `elements` registry on purpose.
  const allElements = document.getElementsByTagName("*");

  for (const element of allElements) {
    addElement(element);
  }
}

addEventListeners();
captureDocumentState();
/**
 * @file contentMutationObserverImplementation.js
 * @description Alternative content script: sends an initial snapshot of the
 * document and then one message per DOM mutation to the background script.
 */

let logged = false;
let ignoreAttributes = new Set();

/*
 * Function to send a message to the background script
 */
function sendMessageToBackgroundScript(message) {
  chrome.runtime.sendMessage(message);
}

/*
 * Function to capture initial document state and
 * send it to the background script
 */
function captureDocumentState() {
  const documentBody = document.body.outerHTML;
  const documentHead = document.head.outerHTML;
  const page_url = window.location.href;

  sendMessageToBackgroundScript({
    action: "captureDocumentState",
    documentBody: documentBody,
    documentHead: documentHead,
    url: page_url,
    timestamp: Date.now(),
  });
}

const observer = new MutationObserver((mutations) => {
  mutations.forEach((mutation) => {
    const { type, target } = mutation;
    // characterData mutations are reported on Text nodes, which have no
    // tagName, attributes or bounding rect; describe the owning element.
    const element =
      target.nodeType === Node.ELEMENT_NODE ? target : target.parentElement;
    if (!element) {
      return; // detached text node: nothing to describe
    }
    const tagName = element.tagName.toLowerCase();
    const attributes = {};

    for (const attr of element.attributes) {
      attributes[attr.name] = attr.value;
    }

    const rect = element.getBoundingClientRect();
    const x = rect.left + window.scrollX;
    const y = rect.top + window.scrollY;
    const value = element.value;

    sendMessageToBackgroundScript({
      action: "mutation",
      type: type,
      tagName: tagName,
      attributes: attributes,
      x: x,
      y: y,
      value: value,
      url: window.location.href,
      timestamp: Date.now(),
    });
  });
});

observer.observe(document.body, {
  childList: true,
  subtree: true,
  attributes: true,
  attributeOldValue: true,
  characterData: true,
  characterDataOldValue: true,
});
captureDocumentState();
messages") + +messages = c.fetchall() + +conn.close() + +logo_path = "icons/logo.png" +with open(logo_path, "rb") as f: + logo_data = f.read() +logo_base64 = base64.b64encode(logo_data).decode("utf-8") + +page = f""" + +
| Table: messages | ||
| # | id INTEGER | message TEXT |
| {i+1} | {message[0]} | |
" + document_body + ""}], + } + ) + document_head = html.escape(d["message"]["documentHead"]) + children.append( + { + "id": "documentHead", + "children": [{"id": "
" + document_head + ""}], + } + ) + children.append({"id": "url", "children": [{"id": str(d["message"]["url"])}]}) + else: + tagName = d["message"]["tagName"] + children.append({"id": "tagName", "children": [{"id": str(tagName)}]}) + + if d["message"]["action"] == "elementInput": + value = d["message"]["value"] + children.append({"id": "value", "children": [{"id": str(value)}]}) + + x, y = d["message"]["x"], d["message"]["y"] + children.append({"id": "x", "children": [{"id": str(x)}]}) + children.append({"id": "y", "children": [{"id": str(y)}]}) + + attributes = d["message"]["attributes"] + children.append({"id": "attributes", "children": [{"id": str(attributes)}]}) + + tree[0]["children"] = children + ui.tree(tree, label_key="id")._props["default-expand-all"] = True + +ui.run(port=7777) diff --git a/openadapt/config.py b/openadapt/config.py index 9aab3fa92..5e222e301 100644 --- a/openadapt/config.py +++ b/openadapt/config.py @@ -88,6 +88,10 @@ "children", ], "PLOT_PERFORMANCE": True, + "SOCKET_PORT": 6001, + "SOCKET_AUTHKEY": b"openadapt", + "SOCKET_ADDRESS": "localhost", + "SOCKET_RETRY_INTERVAL": 5, # seconds # Calculate and save the difference between 2 neighboring screenshots "SAVE_SCREENSHOT_DIFF": False, "SPACY_MODEL_NAME": "en_core_web_trf", @@ -200,7 +204,8 @@ def obfuscate(val: str, pct_reveal: float = 0.1, char: str = "*") -> str: return rval -_OBFUSCATE_KEY_PARTS = ("KEY", "PASSWORD", "TOKEN") +_OBFUSCATE_KEY_PARTS = ("KEY", "PASSWORD", "TOKEN", "AUTHKEY", "PORT", "ADDRESS") + if multiprocessing.current_process().name == "MainProcess": for key, val in dict(locals()).items(): if not key.startswith("_") and key.isupper(): diff --git a/openadapt/crud.py b/openadapt/crud.py index 8493a4e36..2ce6f24ec 100644 --- a/openadapt/crud.py +++ b/openadapt/crud.py @@ -2,7 +2,6 @@ Module: crud.py """ - from typing import Any from loguru import logger @@ -12,6 +11,7 @@ from openadapt.db import BaseModel, Session from openadapt.models import ( ActionEvent, + 
def insert_browser_event(
    recording_timestamp: int, event_timestamp: int, event_data: dict[str, Any] = None
) -> None:
    """Insert a browser event into the database.

    Args:
        recording_timestamp (int): The timestamp of the recording.
        event_timestamp (int): The timestamp of the event.
        event_data (dict): The data of the event; ``None`` (the default) is
            treated as an empty dict.
    """
    # The declared default is None, which cannot be unpacked with ``**``.
    event_data = {
        **(event_data or {}),
        "timestamp": event_timestamp,
        "recording_timestamp": recording_timestamp,
    }
    _insert(event_data, BrowserEvent, browser_events)


def get_browser_events(recording: Recording) -> list[BrowserEvent]:
    """Get browser events for a given recording.

    Args:
        recording (Recording): recording object

    Returns:
        List[BrowserEvent]: list of browser events
    """
    return _get(BrowserEvent, recording.timestamp)
num_window_events = len(window_events) num_screenshots = len(screenshots) + num_browser_events = len(browser_events) num_action_events_raw = num_action_events num_window_events_raw = num_window_events num_screenshots_raw = num_screenshots + num_browser_events_raw = num_browser_events duration_raw = action_events[-1].timestamp - action_events[0].timestamp num_process_iters = 0 @@ -62,26 +65,31 @@ def get_events( f"{num_action_events=} " f"{num_window_events=} " f"{num_screenshots=}" + f"{num_browser_events=}" ) ( action_events, window_events, screenshots, + browser_events, ) = process_events( action_events, window_events, screenshots, + browser_events, ) if ( len(action_events) == num_action_events and len(window_events) == num_window_events and len(screenshots) == num_screenshots + and len(browser_events) == num_browser_events ): break num_process_iters += 1 num_action_events = len(action_events) num_window_events = len(window_events) num_screenshots = len(screenshots) + num_browser_events = len(browser_events) if num_process_iters == MAX_PROCESS_ITERS: break @@ -102,6 +110,10 @@ def get_events( num_screenshots, num_screenshots_raw, ) + meta["num_browser_events"] = format_num( + num_browser_events, + num_browser_events_raw, + ) duration = action_events[-1].timestamp - action_events[0].timestamp if len(action_events) > 1: @@ -112,7 +124,7 @@ def get_events( duration = end_time - start_time logger.info(f"{duration=}") - return action_events # , window_events, screenshots + return action_events # , window_events, screenshots, browser_events def make_parent_event( @@ -135,9 +147,11 @@ def make_parent_event( "recording_timestamp": child.recording_timestamp, "window_event_timestamp": child.window_event_timestamp, "screenshot_timestamp": child.screenshot_timestamp, + "browser_event_timestamp": child.browser_event_timestamp, "recording": child.recording, "window_event": child.window_event, "screenshot": child.screenshot, + "browser_event": child.browser_event, } extra = 
extra or {} for key, val in extra.items(): @@ -685,6 +699,7 @@ def discard_unused_events( def process_events( action_events: list[models.ActionEvent], window_events: list[models.WindowEvent], + browser_events: list[models.BrowserEvent], screenshots: list[models.Screenshot], ) -> tuple[ list[models.ActionEvent], @@ -710,9 +725,14 @@ def process_events( num_action_events = len(action_events) num_window_events = len(window_events) num_screenshots = len(screenshots) - num_total = num_action_events + num_window_events + num_screenshots + num_browser_events = len(browser_events) + num_total = ( + num_action_events + num_window_events + num_screenshots + num_browser_events + ) logger.info( - f"before {num_action_events=} {num_window_events=} {num_screenshots=} " + "before" + f" {num_action_events=} {num_window_events=}" + f" {num_screenshots=} {num_browser_events=} " f"{num_total=}" ) process_fns = [ @@ -747,19 +767,31 @@ def process_events( action_events, "screenshot_timestamp", ) + browser_events = discard_unused_events( + browser_events, + action_events, + "browser_event_timestamp", + ) + num_action_events_ = len(action_events) num_window_events_ = len(window_events) num_screenshots_ = len(screenshots) - num_total_ = num_action_events_ + num_window_events_ + num_screenshots_ + num_browser_events_ = len(browser_events) + num_total_ = ( + num_action_events_ + num_window_events_ + num_screenshots_ + num_browser_events_ + ) pct_action_events = num_action_events_ / num_action_events pct_window_events = num_window_events_ / num_window_events pct_screenshots = num_screenshots_ / num_screenshots + # pct_browser_events = num_browser_events_ / num_browser_events pct_total = num_total_ / num_total logger.info( - f"after {num_action_events_=} {num_window_events_=} {num_screenshots_=} " + "after" + f" {num_action_events_=} {num_window_events_=}" + f" {num_screenshots_=} {num_browser_events_} " f"{num_total=}" ) logger.info( f"{pct_action_events=} {pct_window_events=} 
class BrowserEvent(db.Base):
    """Class representing a browser event in the database."""

    __tablename__ = "browser_event"

    # Surrogate integer primary key.
    id = sa.Column(sa.Integer, primary_key=True)
    # Foreign key to ``recording.timestamp``; links the event to its recording.
    recording_timestamp = sa.Column(sa.ForeignKey("recording.timestamp"))
    # Paired with ``Recording.browser_events`` via ``back_populates``.
    recording = sa.orm.relationship("Recording", back_populates="browser_events")
    # JSON payload of the event.
    message = sa.Column(sa.JSON)
    # Event timestamp stored as ``ForceFloat``.
    timestamp = sa.Column(ForceFloat)


# recording = sa.orm.relationship("Recording", back_populates="browser_events")
# action_event = sa.orm.relationship("ActionEvent", back_populates="browser_event")

# TODO: implement for extension
# @classmethod
# def get_active_browser_event(cls: Any) -> Any:
#     """Get the active chrome tab window's DOM"""
#     return BrowserEvent(**get_active_chrome_data())
process_events( screen_write_q: A queue for writing screen events. action_write_q: A queue for writing action events. window_write_q: A queue for writing window events. + browser_write_q: A queue for writing browser events. perf_q: A queue for collecting performance data. recording_timestamp: The timestamp of the recording. terminate_event: An event to signal the termination of the process. @@ -183,8 +185,10 @@ def process_events( prev_event = None prev_screen_event = None prev_window_event = None + prev_browser_event = None prev_saved_screen_timestamp = 0 prev_saved_window_timestamp = 0 + prev_saved_browser_timestamp = 0 while not terminate_event.is_set() or not event_q.empty(): event = event_q.get() logger.trace(f"{event=}") @@ -198,6 +202,8 @@ def process_events( prev_screen_event = event elif event.type == "window": prev_window_event = event + elif event.type == "browser": + prev_browser_event = event elif event.type == "action": if prev_screen_event is None: logger.warning("Discarding action that came before screen") @@ -207,6 +213,12 @@ def process_events( continue event.data["screenshot_timestamp"] = prev_screen_event.timestamp event.data["window_event_timestamp"] = prev_window_event.timestamp + if prev_browser_event is not None: + event.data["browser_event_timestamp"] = ( + prev_browser_event.message["timestamp"] + if prev_browser_event is not None + else None + ) process_event( event, action_write_q, @@ -232,6 +244,19 @@ def process_events( perf_q, ) prev_saved_window_timestamp = prev_window_event.timestamp + if prev_browser_event is not None: + if prev_saved_browser_timestamp < prev_browser_event.msg["timestamp"]: + process_event( + prev_browser_event, + browser_write_q, + write_browser_event, + recording_timestamp, + perf_q, + ) + if prev_browser_event is not None: + prev_saved_browser_timestamp = prev_browser_event.message[ + "timestamp" + ] else: raise Exception(f"unhandled {event.type=}") del prev_event @@ -293,6 +318,23 @@ def write_window_event( 
def write_browser_event(
    recording_timestamp: float,
    event: Event,
    perf_q: sq.SynchronizedQueue,
) -> None:
    """Write a browser event to the database and update the performance queue.

    Args:
        recording_timestamp: The timestamp of the recording.
        event: A browser event to be written.
        perf_q: A queue for collecting performance data.
    """
    assert event.type == "browser", event
    # The browser payload is stored under the ``message`` key, the name of the
    # JSON column on the ``browser_event`` table; its own keys (action,
    # tagName, ...) are not columns.
    # NOTE(review): assumes crud keys event data by column name -- confirm.
    crud.insert_browser_event(
        recording_timestamp, event.timestamp, {"message": event.data}
    )
    perf_q.put((event.type, event.timestamp, utils.get_timestamp()))


def read_browser_events(
    event_q: queue.Queue,
    terminate_event: multiprocessing.Event,
    recording_timestamp: float,
) -> None:
    """Read browser events from the socket and add them to the event queue.

    The connection is polled with a timeout rather than read with a blocking
    ``recv()``, so ``terminate_event`` is honoured even when the browser sends
    nothing. A failed connection attempt is retried every
    ``config.SOCKET_RETRY_INTERVAL`` seconds until termination.

    Args:
        event_q: A queue for adding browser events.
        terminate_event: An event to signal the termination of the process.
        recording_timestamp: The timestamp of the recording.
    """
    poll_interval = 1  # seconds between checks of terminate_event

    utils.configure_logging(logger, LOG_LEVEL)
    utils.set_start_time(recording_timestamp)
    logger.info("starting")
    conn = None
    try:
        while not terminate_event.is_set():
            if conn is None or conn.closed:
                try:
                    conn = sockets.create_client_connection(config.SOCKET_PORT)
                except OSError as exc:
                    # Typically the native messaging host is not running yet.
                    logger.warning(f"Failed to connect: {exc}")
                    # wait() doubles as an interruptible sleep.
                    terminate_event.wait(config.SOCKET_RETRY_INTERVAL)
                    continue
            try:
                if not conn.poll(poll_interval):
                    continue
                msg = conn.recv()
            except (EOFError, OSError) as exc:
                logger.warning("Connection closed.")
                logger.warning(exc)
                break
            logger.info(f"{msg=}")

            if msg is None:
                logger.info("No message received or received None Type Message.")
                continue
            logger.info("Received message.")
            logger.debug("queuing browser event for writing")
            event_q.put(
                Event(
                    utils.get_timestamp(),
                    "browser",
                    msg,
                )
            )
    finally:
        if conn is not None and not conn.closed:
            conn.close()

    logger.info("done")
args=(event_q, terminate_event, recording_timestamp), + ) + browser_event_reader.start() + screen_event_reader = threading.Thread( target=read_screen_events, args=(event_q, terminate_event, recording_timestamp), @@ -870,6 +981,7 @@ def record( screen_write_q, action_write_q, window_write_q, + browser_write_q, perf_q, recording_timestamp, terminate_event, @@ -891,6 +1003,20 @@ def record( ) screen_event_writer.start() + browser_event_writer = multiprocessing.Process( + target=write_events, + args=( + "browser", + write_browser_event, + browser_write_q, + perf_q, + recording_timestamp, + terminate_event, + term_pipe_child_action, + ), + ) + browser_event_writer.start() + action_event_writer = multiprocessing.Process( target=write_events, args=( @@ -961,16 +1087,19 @@ def record( term_pipe_parent_window.send(window_write_q.qsize()) term_pipe_parent_action.send(action_write_q.qsize()) term_pipe_parent_screen.send(screen_write_q.qsize()) + term_pipe_parent_browser.send(browser_write_q.qsize()) logger.info("joining...") keyboard_event_reader.join() mouse_event_reader.join() screen_event_reader.join() window_event_reader.join() + browser_event_reader.join() event_processor.join() screen_event_writer.join() action_event_writer.join() window_event_writer.join() + browser_event_writer.join() terminate_perf_event.set() if PLOT_PERFORMANCE: diff --git a/openadapt/sockets.py b/openadapt/sockets.py new file mode 100644 index 000000000..1fde00be0 --- /dev/null +++ b/openadapt/sockets.py @@ -0,0 +1,228 @@ +"""Module for managing socket connections and communication.""" + +from multiprocessing import Queue +from multiprocessing.connection import Client, Connection, Listener +from typing import Any, Optional +import time + +from loguru import logger + +from openadapt import config + +client_by_port = {} +server_by_port = {} +queue_by_port = {} + + +def client_send_message(port: int, msg: Any) -> None: + """Send a message to the client connection associated with the given port. 
+ + Args: + port: The port number associated with the client connection. + msg: The message to be sent. + + Returns: + None + """ + client_conn = client_by_port.get(port) + if client_conn: + client_conn.send(msg) + + +def server_send_message(port: int, msg: Any) -> None: + """Send a message to the server connection associated with the given port. + + Args: + port: The port number associated with the server connection. + msg: The message to be sent. + + Returns: + None + """ + server_conn = server_by_port.get(port) + if server_conn: + server_conn.send(msg) + + +def client_receive_message(port: int) -> Optional[str]: + """Receive a message from the client connection associated with the given port. + + Args: + port: The port number associated with the client connection. + + Returns: + The received message as a string, or None if no message is available. + """ + client_conn = client_by_port.get(port) + if client_conn: + try: + if message := client_conn.recv(): + return message + except Exception as exc: + logger.error("Connection was closed.") + logger.error(exc) + del client_by_port[port] + + +def server_receive_message(port: int) -> Optional[str]: + """Receive a message from the server connection associated with the given port. + + Args: + port: The port number associated with the server connection. + + Returns: + The received message as a string, or None if no message is available. + """ + server_conn = server_by_port.get(port) + while True: + if server_conn: + try: + message = server_conn.recv() + return message + except EOFError: + logger.warning("Connection closed. Reconnecting...") + while True: + try: + server_conn = create_server_connection(port) + break + except Exception as exc: + logger.warning(f"Failed to reconnect: {exc}") + time.sleep(config.SOCKET_RETRY_INTERVAL) + return None + + +def client_add_sink(port: int, queue: Queue) -> None: + """Add a sink queue to the specified client port. + + Args: + port: The port number to associate with the sink queue. 
def server_add_sink(port: int, queue: Queue) -> None:
    """Add a sink queue to the specified server port.

    Args:
        port: The port number to associate with the sink queue.
        queue: The queue to be added as a sink.

    Raises:
        ValueError: If the specified port already has a sink assigned.

    Returns:
        None
    """
    if port in queue_by_port:
        raise ValueError(f"Port {port} already has a sink assigned.")
    queue_by_port[port] = queue


# Termination event polled by the event loop via ``.is_set()``; it is an
# event-like object, not a bool. None until set_terminate_event() is called.
_terminate_event: Optional[Any] = None


def set_terminate_event(terminate_event: Any) -> None:
    """Set the termination event to control the event loop.

    Args:
        terminate_event: The termination event object; it must expose
            ``is_set()``.

    Returns:
        None
    """
    global _terminate_event
    _terminate_event = terminate_event


def create_client_connection(port: int) -> Connection:
    """Create a client connection to the specified port and register it.

    The connection is stored in ``client_by_port`` under ``port``.

    Args:
        port: The port number to connect to.

    Returns:
        The created client connection object.
    """
    address = (config.SOCKET_ADDRESS, port)
    conn = Client(address, authkey=config.SOCKET_AUTHKEY)
    client_by_port[port] = conn
    logger.info("Connected to the Client.")
    return conn


def create_server_connection(port: int) -> Connection:
    """Listen on the specified port and accept a single connection.

    Blocks until a client connects. The accepted connection is stored in
    ``server_by_port`` under ``port``.

    Args:
        port: The port number to bind the server connection to.

    Returns:
        The accepted server connection object.
    """
    address = (config.SOCKET_ADDRESS, port)
    # Close the listening socket once a peer has been accepted. Leaving it
    # open leaks the socket and keeps the port bound, which makes a later
    # call for the same port (e.g. a reconnect) fail to bind. Closing the
    # listener does not close the accepted connection.
    with Listener(address, authkey=config.SOCKET_AUTHKEY) as listener:
        conn = listener.accept()
    server_by_port[port] = conn
    logger.info("Connected to the Server.")
    return conn
def event_loop() -> None:
    """Poll every registered client connection until termination is requested.

    Each pass calls ``recv()`` on every connection in ``client_by_port``.
    A connection that raises ``EOFError`` is removed from ``client_by_port``,
    together with any sink registered for the same port in ``queue_by_port``.
    Received messages are currently discarded.

    Raises:
        AssertionError: If `_terminate_event` is not set.

    Returns:
        None
    """
    if not _terminate_event:
        # Raised explicitly instead of using `assert`, so the check is not
        # stripped when Python runs with -O.
        raise AssertionError("You must call set_terminate_event")
    while not _terminate_event.is_set():
        # Iterate over a snapshot: entries are removed inside the loop, and
        # mutating a dict while iterating over it raises RuntimeError.
        for port, client_conn in list(client_by_port.items()):
            try:
                message = client_conn.recv()  # noqa: F841
                # TODO: Handle message
            except EOFError:
                # Connection closed: forget it and its sink. pop() tolerates
                # ports that never had a sink registered.
                client_by_port.pop(port, None)
                queue_by_port.pop(port, None)
    # TODO: poll server_by_port the same way and forward each received
    # message to the sink registered for its port in queue_by_port.


def server_sends(conn: Connection, message: Any) -> None:
    """Send a message over the given connection.

    Does nothing when ``conn`` is falsy (e.g. None).

    Args:
        conn: The connection to send on.
        message: The message to be sent.

    Returns:
        None
    """
    if conn:
        conn.send(message)


def client_receive(conn: Connection) -> Any:
    """Receive a message from the given connection.

    Args:
        conn: The connection to receive from.

    Returns:
        The received message, or None if ``conn`` is falsy or the connection
        was closed by the peer.
    """
    if not conn:
        return None
    try:
        return conn.recv()
    except EOFError:
        logger.warning("Connection was closed.")
        return None
def get_free_port() -> int:
    """Get a free port number on the local machine.

    A temporary socket is bound to port 0 so the OS picks an unused port;
    the socket is closed before returning, so another process may claim the
    port before the caller binds to it.

    Returns:
        An available free port number.

    Raises:
        OSError: If a free port number cannot be obtained.
    """
    # The context manager closes the socket even if bind() raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as temp_socket:
        temp_socket.bind(("localhost", 0))
        _, port = temp_socket.getsockname()
    return port


def send_kill_signal(pid: int) -> None:
    """Send a kill signal (SIGTERM) to the process identified by the PID.

    This is best-effort: a failure to deliver the signal is logged and
    swallowed, not raised to the caller.

    Args:
        pid (int): The PID of the process.
    """
    try:
        os.kill(pid, signal.SIGTERM)
    except OSError as e:
        # Logged as a warning so a failed kill is visible at default levels.
        logger.warning(f"Failed to send kill signal: {e}")
    else:
        logger.info("Kill signal sent successfully.")


def get_pid_by_name(process_name: str) -> "int | None":
    """Get the PID of the first process with the given name.

    Args:
        process_name (str): The name of the process.

    Returns:
        The PID of the first matching process, or None if no running
        process has that name.
    """
    matching_pids = (
        process.info["pid"]
        for process in psutil.process_iter(["pid", "name"])
        if process.info["name"] == process_name
    )
    return next(matching_pids, None)
diff --git a/openadapt/visualize.py b/openadapt/visualize.py index bad2753ef..caf0bd5f2 100644 --- a/openadapt/visualize.py +++ b/openadapt/visualize.py @@ -264,10 +264,12 @@ def main(recording: Recording = None) -> bool: action_event_dict = row2dict(action_event) window_event_dict = row2dict(action_event.window_event) + browser_event_dict = row2dict(action_event.browser_event) if SCRUB: action_event_dict = scrub.scrub_dict(action_event_dict) window_event_dict = scrub.scrub_dict(window_event_dict) + browser_event_dict = scrub.scrub_dict(browser_event_dict) rows.append( [ @@ -297,6 +299,9 @@ def main(recording: Recording = None) -> bool: