diff --git a/main.py b/main.py
index b60c27e..a259092 100644
--- a/main.py
+++ b/main.py
@@ -11,7 +11,7 @@
 from dataclasses import dataclass
 from typing import Optional, Dict, Any, List
-from flask import Flask, request, jsonify
+from flask import Flask, request, jsonify, make_response
 from flask_httpauth import HTTPBasicAuth
 from flask_cors import CORS
 from cachetools import LRUCache
@@ -24,6 +24,8 @@
 import logging
 import sys
 from threading import Lock
+from queue import Queue
+
 # Configure logging
 logging.basicConfig(
@@ -64,19 +66,22 @@ def signal_handler(signum, frame):
 # Initialize LRU Cache
 cache = LRUCache(maxsize=10)

-# Add near the top of the file, after imports but before app initialization
-from threading import Lock
-from typing import Dict, Optional
-class ConnectionManager:
-    def __init__(self):
-        self._connections: Dict[str, duckdb.DuckDBPyConnection] = {}
+# Connection pool implementation
+class ConnectionPool:
+    def __init__(self, db_file, max_connections=5):
+        self.db_file = db_file
+        self.max_connections = max_connections
+        self._pool = Queue(maxsize=max_connections)
         self._lock = Lock()
-
-        # Create default in-memory connection
-        self._default_conn = duckdb.connect(':memory:')
-        self._setup_extensions(self._default_conn)
-
+        for _ in range(max_connections):
+            self._pool.put(self._create_connection())
+
+    def _create_connection(self):
+        conn = duckdb.connect(self.db_file)
+        self._setup_extensions(conn)
+        return conn
+
     def _setup_extensions(self, conn: duckdb.DuckDBPyConnection):
         """Set up required extensions for a connection"""
         try:
@@ -86,34 +91,68 @@
             conn.load_extension("chsql_native")
         except Exception as e:
             logger.warning(f"Failed to initialize extensions: {e}")
-
+
+    def get_connection(self):
+        """Acquire a connection from the pool.
+
+        Queue is already thread-safe, so no extra lock is taken here: a
+        blocked get() holding self._lock would deadlock the put() in
+        return_connection() once the pool runs empty.
+        """
+        return self._pool.get()
+
+    def return_connection(self, conn):
+        """Return a connection to the pool."""
+        self._pool.put(conn)
+
+    def close_all_connections(self):
+        """Close all connections in the pool."""
+        while not self._pool.empty():
+            conn = self._pool.get()
+            conn.close()
+
+
+class ConnectionManager:
+    def __init__(self):
+        self._connection_pools: Dict[str, ConnectionPool] = {}
+        self._lock = Lock()
+        self._default_pool = ConnectionPool(':memory:')
+
     def get_connection(self, auth_hash: Optional[str] = None) -> duckdb.DuckDBPyConnection:
         """Get or create a connection for the given auth hash"""
         if not auth_hash:
-            return self._default_conn
-
+            return self._default_pool.get_connection()
+
         with self._lock:
-            if auth_hash not in self._connections:
+            if auth_hash not in self._connection_pools:
                 db_file = os.path.join(dbpath, f"{auth_hash}.db")
-                logger.info(f'Creating new connection for {db_file}')
-                conn = duckdb.connect(db_file)
-                self._setup_extensions(conn)
-                self._connections[auth_hash] = conn
-            return self._connections[auth_hash]
+                logger.info(f'Creating new connection pool for {db_file}')
+                pool = ConnectionPool(db_file)
+                self._connection_pools[auth_hash] = pool
+            return self._connection_pools[auth_hash].get_connection()
+
+    def return_connection(self, conn: duckdb.DuckDBPyConnection, auth_hash: Optional[str] = None):
+        """Return a connection to the pool"""
+        with self._lock:
+            if not auth_hash:
+                self._default_pool.return_connection(conn)
+            elif auth_hash in self._connection_pools:
+                self._connection_pools[auth_hash].return_connection(conn)
+            else:
+                logger.warning(f'Pool not found for {auth_hash}, closing connection')
+                conn.close()
+
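+# Usage sketch (illustrative only, not part of the public API): every call
+# that takes a connection from the manager must pair it with
+# return_connection(), even on error paths:
+#
+#     conn = connection_manager.get_connection(auth_hash)
+#     try:
+#         conn.execute("SELECT 1")
+#     finally:
+#         connection_manager.return_connection(conn, auth_hash)
+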
 # Create global connection manager
 connection_manager = ConnectionManager()

-# Replace the global conn variable with a property that uses the connection manager
+
 def get_current_connection() -> duckdb.DuckDBPyConnection:
     """Get the current connection based on authentication"""
-    auth = request.authorization if hasattr(request, 'authorization') else None
-    if auth and auth.username and auth.password:
-        user_pass_hash = hashlib.sha256((auth.username + auth.password).encode()).hexdigest()
+    auth_obj = request.authorization if hasattr(request, 'authorization') else None
+    if auth_obj and auth_obj.username and auth_obj.password:
+        user_pass_hash = hashlib.sha256((auth_obj.username + auth_obj.password).encode()).hexdigest()
         return connection_manager.get_connection(user_pass_hash)
     return connection_manager.get_connection()

-# Remove the global conn variable and replace with this property
+
 @property
 def conn():
     return get_current_connection()
@@ -124,10 +163,9 @@ def verify(username, password):
     if not (username and password):
         logger.debug('Using stateless session')
         return True
-
+
-    logger.info(f"Using http auth: {username}:{password}")
+    # Avoid logging credentials in plaintext
+    logger.info(f"Using http auth for user: {username}")
     user_pass_hash = hashlib.sha256((username + password).encode()).hexdigest()
-    # Just verify the connection exists/can be created
-    connection_manager.get_connection(user_pass_hash)
+    # Verify the pool exists/can be created, then return the connection
+    # immediately so verification does not leak connections from the pool
+    conn = connection_manager.get_connection(user_pass_hash)
+    connection_manager.return_connection(conn, user_pass_hash)
     return True
@@ -138,8 +176,8 @@ def convert_to_ndjson(result):
     ndjson_lines = []
     for row in data:
         row_dict = {columns[i][0]: row[i] for i in range(len(columns))}
-        ndjson_lines.append(json.dumps(row_dict))
-    return '\n'.join(ndjson_lines).encode()
+        ndjson_lines.append(json.dumps(row_dict, ensure_ascii=False))
+    return '\n'.join(ndjson_lines).encode('utf-8')


 def convert_to_clickhouse_jsoncompact(result, query_time):
@@ -157,7 +195,7 @@
             "bytes_read": sum(len(str(item)) for row in data for item in row)
         }
     }
-    return json.dumps(json_result)
+    return json.dumps(json_result, ensure_ascii=False)


 def convert_to_clickhouse_json(result, query_time):
@@ -178,7 +216,7 @@
             "bytes_read": sum(len(str(item)) for row in data for item in row)
         }
     }
-    return json.dumps(json_result)
+    return json.dumps(json_result, ensure_ascii=False)


 def convert_to_csv_tsv(result, delimiter=','):
@@ -190,12 +228,12 @@
     for row in data:
         line = delimiter.join([str(item) for item in row])
         lines.append(line)
-    return '\n'.join(lines).encode()
+    return '\n'.join(lines).encode('utf-8')

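+# NOTE: handle_insert_query() below extracts the table name by splitting on
+# the literal "INTO", which is case-sensitive and whitespace-fragile. A more
+# robust alternative (sketch only, not wired in; assumes `re` is imported):
+#
+#     m = re.search(r'(?i)insert\s+into\s+(\S+)', query)
+#     table_name = m.group(1) if m else None
+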
-def handle_insert_query(query, format, data=None, conn=None):
-    if conn is None:
-        conn = get_current_connection()
+def handle_insert_query(query, format, data=None, current_conn=None):
+    if current_conn is None:
+        current_conn = get_current_connection()
     table_name = query.split("INTO")[1].split()[0].strip()
     temp_file_name = None
     if format.lower() == 'jsoneachrow' and data is not None:
@@ -203,9 +241,9 @@
     if temp_file_name:
         try:
             ingest_query = f"COPY {table_name} FROM '{temp_file_name}' (FORMAT 'json')"
-            conn.execute(ingest_query)
+            current_conn.execute(ingest_query)
         except Exception as e:
-            return b"", str(e).encode()
+            return b"", str(e).encode('utf-8')
         finally:
             os.remove(temp_file_name)
     return b"Ok", b""


 def save_to_tempfile(data):
@@ -221,10 +259,10 @@


 def duckdb_query_with_errmsg(query, format='JSONCompact', data=None, request_method="GET"):
+    current_conn = None
     try:
-        # Get connection for current request
         current_conn = get_current_connection()
-
+
         if request_method == "POST" and query.strip().lower().startswith('insert into') and data:
             return handle_insert_query(query, format, data, current_conn)
         start_time = time.time()
@@ -243,10 +281,16 @@
         else:
             output = result.fetchall()
         if isinstance(output, list):
-            output = json.dumps(output).encode()
+            output = json.dumps(output, ensure_ascii=False).encode('utf-8')
         return output, b""
     except Exception as e:
-        return b"", str(e).encode()
+        logger.error(f"Query execution error: {e}", exc_info=True)
+        return b"", str(e).encode('utf-8')
+    finally:
+        # Always hand the connection back to the pool, even on error
+        if current_conn:
+            auth_obj = request.authorization if hasattr(request, 'authorization') else None
+            user_pass_hash = hashlib.sha256((auth_obj.username + auth_obj.password).encode()).hexdigest() if auth_obj and auth_obj.username and auth_obj.password else None
+            connection_manager.return_connection(current_conn, user_pass_hash)
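+# A context manager could make the acquire/release pattern above reusable
+# (sketch only, assuming contextlib is imported; not wired in anywhere yet):
+#
+#     @contextlib.contextmanager
+#     def pooled_connection(auth_hash=None):
+#         conn = connection_manager.get_connection(auth_hash)
+#         try:
+#             yield conn
+#         finally:
+#             connection_manager.return_connection(conn, auth_hash)
+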
POST", query, format) + logger.debug(f"DEBUG POST query: {query}, format: {format}") result, errmsg = duckdb_query_with_errmsg(query.strip(), format) if len(errmsg) == 0: - return result, 200 - if len(result) > 0: - print("warning:", errmsg) - return result, 200 - return errmsg, 400 + response = make_response(result, 200) + content_type = 'application/json' + if format.lower() == 'csv': + content_type = 'text/csv' + elif format.lower() == 'tsv': + content_type = 'text/tab-separated-values' + elif format.lower() == 'jsoneachrow' or format.lower() == 'json' or format.lower() == 'jsoncompact': + content_type = 'application/json' + response.headers['Content-Type'] = content_type + return response + if len(errmsg) > 0: + logger.warning(f"Query warning: {errmsg}") + response = make_response(result, 200) + response.headers['Content-Type'] = 'application/json' + return response + logger.error(f"Query error: {errmsg}") + return errmsg.decode('utf-8'), 400 @app.route('/play', methods=["GET"]) @@ -358,17 +430,15 @@ def handle_404(e): flight_port = int(os.getenv('FLIGHT_PORT', 8815)) path = os.getenv('DATA', '.duckdb_data') + def parse_ticket(ticket): try: - # Try to decode the ticket as a JSON object ticket_obj = json.loads(ticket.ticket.decode("utf-8")) if isinstance(ticket_obj, str): - # If the JSON object is a string, parse it again ticket_obj = json.loads(ticket_obj) if "query" in ticket_obj: return ticket_obj["query"] except (json.JSONDecodeError, AttributeError): - # If decoding fails or "query" is not in the object, return the ticket as a string return ticket.ticket.decode("utf-8") @@ -405,7 +475,6 @@ def deserialize(cls, data: bytes) -> 'FlightSchemaMetadata': return cls(**metadata) -# Add this helper class for schema serialization @dataclass class SerializedSchema: schema: str @@ -428,9 +497,7 @@ def to_dict(self) -> Dict: } -# Patch the main function where the ticket is processed if __name__ == '__main__': - # Set up signal handlers signal.signal(signal.SIGINT, signal_handler) def run_flask(): @@ -445,10 +512,11 @@ def run_flask(): def run_flight_server(): """Run Flight server""" + class HeaderMiddleware(flight.ServerMiddleware): def __init__(self): self.authorization = None - self.headers = {} # Store all headers + self.headers = {} def call_completed(self, exception=None): pass @@ -458,16 +526,11 @@ def start_call(self, info, headers): logger.debug(f"Info received: {info}") logger.debug(f"Headers received: {headers}") middleware = HeaderMiddleware() - - # Store all headers in the middleware middleware.headers = headers - if "authorization" in headers: - # Get first value from list - auth = headers["authorization"][0] - auth = auth[7:] if auth.startswith('Bearer ') else auth - middleware.authorization = auth - + auth_header = headers["authorization"][0] + auth_header = auth_header[7:] if auth_header.startswith('Bearer ') else auth_header + middleware.authorization = auth_header return middleware class DuckDBFlightServer(flight.FlightServerBase): @@ -477,22 +540,19 @@ def __init__(self, location=f"grpc://{flight_host}:{flight_port}", db_path=":mem self._location = location logger.info(f"Initializing Flight server at {location}") self.conn = duckdb.connect(db_path) - - # Define schema for catalog listing + catalog_schema = pa.schema([ ('catalog_name', pa.string()), ('schema_name', pa.string()), ('description', pa.string()) ]) - - # Define schema for table listing table_schema = pa.schema([ ('table_name', pa.string()), ('schema_name', pa.string()), ('catalog_name', pa.string()), 
     class DuckDBFlightServer(flight.FlightServerBase):
@@ -477,22 +540,19 @@ def __init__(self, location=f"grpc://{flight_host}:{flight_port}", db_path=":memory:"):
             super().__init__(location)
             self._location = location
             logger.info(f"Initializing Flight server at {location}")
             self.conn = duckdb.connect(db_path)
-
-            # Define schema for catalog listing
+
             catalog_schema = pa.schema([
                 ('catalog_name', pa.string()),
                 ('schema_name', pa.string()),
                 ('description', pa.string())
             ])
-
-            # Define schema for table listing
             table_schema = pa.schema([
                 ('table_name', pa.string()),
                 ('schema_name', pa.string()),
                 ('catalog_name', pa.string()),
                 ('table_type', pa.string())
             ])
-
+
             self.flights = [
                 {
                     "command": "show_databases",
@@ -521,264 +581,214 @@
             ]

         def _get_connection_from_context(self, context) -> duckdb.DuckDBPyConnection:
-            """Get the appropriate connection based on Flight context"""
+            """Get the appropriate connection based on Flight context using ConnectionManager"""
+            middleware = context.get_middleware("auth")
+            auth_header = middleware.authorization if middleware and middleware.authorization else None
+            user_pass_hash = None
+            if auth_header:
+                if isinstance(auth_header, str):
+                    if ':' in auth_header:
+                        username, password = auth_header.split(':', 1)
+                        user_pass_hash = hashlib.sha256((username + password).encode()).hexdigest()
+                    else:
+                        user_pass_hash = auth_header
+            conn = connection_manager.get_connection(user_pass_hash)
+            return conn
+
+        def return_connection_to_pool(self, context, conn: duckdb.DuckDBPyConnection):
+            """Return the connection to the pool after use."""
             middleware = context.get_middleware("auth")
-            if middleware and middleware.authorization:
-                auth_header = middleware.authorization
+            auth_header = middleware.authorization if middleware and middleware.authorization else None
+            user_pass_hash = None
+            if auth_header:
                 if isinstance(auth_header, str):
                     if ':' in auth_header:
                         username, password = auth_header.split(':', 1)
                         user_pass_hash = hashlib.sha256((username + password).encode()).hexdigest()
                     else:
                         user_pass_hash = auth_header
-                return connection_manager.get_connection(user_pass_hash)
-            return connection_manager.get_connection()
+            connection_manager.return_connection(conn, user_pass_hash)
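+
+        # The auth-hash derivation above is duplicated between
+        # _get_connection_from_context and return_connection_to_pool; a small
+        # helper could centralize it (sketch, hypothetical name):
+        #
+        #     def _auth_hash_from_context(self, context) -> Optional[str]:
+        #         middleware = context.get_middleware("auth")
+        #         auth_header = middleware.authorization if middleware else None
+        #         if not auth_header or not isinstance(auth_header, str):
+        #             return None
+        #         if ':' in auth_header:
+        #             username, password = auth_header.split(':', 1)
+        #             return hashlib.sha256((username + password).encode()).hexdigest()
+        #         return auth_header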

         def do_action(self, context, action):
             """Handle Flight actions"""
             logger.debug(f"Action Request: {action}")
-
-            if action.type == "list_schemas":
-                try:
-                    # Parse the request body
-                    body = json.loads(action.body.to_pybytes().decode('utf-8'))
-                    catalog_name = body.get("catalog_name", "main")
-
-                    # Query schemas from DuckDB
-                    query = """
-                        SELECT
-                            schema_name as schema,
-                            'DuckDB Schema' as description,
-                            '{}' as tags,
-                            'table' as type
-                        FROM information_schema.schemata
-                        WHERE catalog_name = ?
-                    """
-                    result = self.conn.execute(query, [catalog_name]).fetchall()
-
-                    # Convert results to SerializedSchema objects
-                    schemas = []
-                    for row in result:
-                        schema = SerializedSchema(
-                            schema=catalog_name,
-                            description=row[1],
-                            tags=json.loads(row[2]),
-                            contents={"url": None, "sha256": None, "serialized": None},
-                            type=row[3]
-                        )
-                        schemas.append(schema.to_dict())
-
-                    # Create the catalog root structure
-                    catalog_root = {
-                        "contents": {
-                            "url": None,
-                            "sha256": None,
-                            "serialized": None
-                        },
-                        "schemas": schemas
-                    }
-
-                    # Serialize with msgpack
-                    packed_data = msgpack.packb(catalog_root)
-
-                    # Compress with zstd
-                    compressor = zstd.ZstdCompressor()
-                    compressed_data = compressor.compress(packed_data)
-
-                    # Create result with decompressed length and compressed data
-                    decompressed_length = len(packed_data)
-                    length_bytes = decompressed_length.to_bytes(4, byteorder='little')
-
-                    # Return results as flight.Result objects
-                    yield flight.Result(pa.py_buffer(length_bytes))
-                    yield flight.Result(pa.py_buffer(compressed_data))
-
-                except Exception as e:
-                    logger.exception("Error in list_schemas action")
-                    raise flight.FlightUnavailableError(f"Failed to list schemas: {str(e)}")
-
-            elif action.type == "create_schema":
-                try:
-                    # Set up authenticated connection first
-                    middleware = context.get_middleware("auth")
-                    if middleware and middleware.authorization:
-                        auth_header = middleware.authorization
-                        logger.info(f"Using authorization from middleware: {auth_header}")
-                        if isinstance(auth_header, str):
-                            if ':' in auth_header:
-                                username, password = auth_header.split(':', 1)
-                                user_pass_hash = hashlib.sha256((username + password).encode()).hexdigest()
-                            else:
-                                user_pass_hash = auth_header
-
-                        db_file = os.path.join(dbpath, f"{user_pass_hash}.db")
-                        logger.info(f'Using database file: {db_file}')
-                        self.conn = duckdb.connect(db_file)
-
-                    # Try msgpack first
-                    try:
-                        body = msgpack.unpackb(action.body.to_pybytes())
-                    except:
-                        # Fall back to UTF-8 if msgpack fails
-                        body = action.body.to_pybytes().decode('utf-8')
-
-                    # Extract schema name from the full path (e.g., deltalake.test1 -> test1)
-                    schema_name = body.split('.')[-1] if '.' in body else body
-
-                    # Create schema in the authenticated database
-                    query = f"CREATE SCHEMA IF NOT EXISTS {schema_name}"
-                    logger.debug(f"Creating schema with query: {query}")
-                    self.conn.execute(query)
-
-                except Exception as e:
-                    logger.exception("Error in create_schema action")
-                    raise flight.FlightUnavailableError(f"Failed to create schema: {str(e)}")
-
-            elif action.type == "create_table":
-                try:
-                    # Set up authenticated connection first
-                    middleware = context.get_middleware("auth")
-                    if middleware and middleware.authorization:
-                        auth_header = middleware.authorization
-                        logger.info(f"Using authorization from middleware: {auth_header}")
-                        if isinstance(auth_header, str):
-                            if ':' in auth_header:
-                                username, password = auth_header.split(':', 1)
-                                user_pass_hash = hashlib.sha256((username + password).encode()).hexdigest()
-                            else:
-                                user_pass_hash = auth_header
-
-                        db_file = os.path.join(dbpath, f"{user_pass_hash}.db")
-                        logger.info(f'Using database file: {db_file}')
-                        self.conn = duckdb.connect(db_file)
-
-                    # Get the raw bytes and parse table info
-                    body_bytes = action.body.to_pybytes()
-                    logger.debug(f"Raw table creation bytes: {body_bytes.hex()}")
-
-                    try:
-                        # Parse Arrow IPC format
-                        reader = pa.ipc.open_stream(pa.py_buffer(body_bytes))
-                        table = reader.read_all()
-
-                        logger.debug(f"Arrow schema: {table.schema}")
-                        logger.debug(f"Column names: {table.column_names}")
-
-                        # Get metadata from schema
-                        schema_metadata = table.schema.metadata
-                        catalog_name = schema_metadata.get(b'catalog_name', b'').decode('utf-8')
-                        schema_name = schema_metadata.get(b'schema_name', b'').decode('utf-8')
-                        table_name = schema_metadata.get(b'table_name', b'').decode('utf-8')
-
-                        # Extract actual schema name (e.g., test1 from deltalake.test1)
-                        actual_schema = schema_name.split('.')[-1] if '.' in schema_name else schema_name
-
-                        # Get columns from schema
-                        columns = []
-                        for field in table.schema:
-                            columns.append({
-                                'name': field.name,
-                                'type': self._arrow_to_duckdb_type(field.type)
-                            })
-
-                        logger.debug(f"Parsed metadata - catalog: {catalog_name}, schema: {schema_name}, table: {table_name}")
-                        logger.debug(f"Columns: {columns}")
-
-                        if not actual_schema or not table_name:
-                            raise flight.FlightUnavailableError(
-                                f"Missing schema_name or table_name in request. Found catalog={catalog_name}, schema={schema_name}, table={table_name}")
-
-                        column_defs = []
-                        for col in columns:
-                            name = col.get('name')
-                            type_ = col.get('type')
-                            if not name or not type_:
-                                raise flight.FlightUnavailableError(f"Invalid column definition: {col}")
-                            column_defs.append(f"{name} {type_}")
-
-                        # Create table in the authenticated database
-                        query = f"""CREATE TABLE IF NOT EXISTS {actual_schema}.{table_name} (
-                            {', '.join(column_defs)}
-                        )"""
-
-                        logger.debug(f"Creating table with query: {query}")
-                        self.conn.execute(query)
-
-                        # Create and return FlightInfo for the newly created table
-                        schema_metadata = FlightSchemaMetadata(
-                            type="table",
-                            catalog=catalog_name,
-                            schema=schema_name,
-                            name=table_name,
-                            comment=None,
-                            input_schema=table.schema
-                        )
-
-                        flight_info = flight.FlightInfo(
-                            table.schema,
-                            flight.FlightDescriptor.for_path(table_name.encode()),
-                            [flight.FlightEndpoint(
-                                ticket=flight.Ticket(
-                                    f"SELECT * FROM {catalog_name}.{schema_name}.{table_name}".encode()
-                                ),
-                                locations=[self._location]
-                            )],
-                            -1,  # total_records
-                            -1,  # total_bytes
-                            schema_metadata.serialize()
-                        )
-
-                        yield flight.Result(flight_info.serialize())
-
-                    except Exception as e:
-                        logger.exception("Failed to parse Arrow IPC data")
-                        raise flight.FlightUnavailableError(f"Invalid Arrow IPC data in request: {str(e)}")
-
-                except Exception as e:
-                    logger.exception("Error in create_table action")
-                    raise flight.FlightUnavailableError(f"Failed to create table: {str(e)}")
-
-            else:
-                raise flight.FlightUnavailableError(f"Action '{action.type}' not implemented")
+            current_conn = None
+            try:
+                current_conn = self._get_connection_from_context(context)
+                if action.type == "list_schemas":
+                    try:
+                        body = json.loads(action.body.to_pybytes().decode('utf-8'))
+                        catalog_name = body.get("catalog_name", "main")
+
+                        query = """
+                            SELECT
+                                schema_name as schema,
+                                'DuckDB Schema' as description,
+                                '{}' as tags,
+                                'table' as type
+                            FROM information_schema.schemata
+                            WHERE catalog_name = ?
+                        """
+                        result = current_conn.execute(query, [catalog_name]).fetchall()
+
+                        schemas = []
+                        for row in result:
+                            schema = SerializedSchema(
+                                schema=catalog_name,
+                                description=row[1],
+                                tags=json.loads(row[2]),
+                                contents={"url": None, "sha256": None, "serialized": None},
+                                type=row[3]
+                            )
+                            schemas.append(schema.to_dict())
+
+                        catalog_root = {
+                            "contents": {
+                                "url": None,
+                                "sha256": None,
+                                "serialized": None
+                            },
+                            "schemas": schemas
+                        }
+
+                        packed_data = msgpack.packb(catalog_root)
+                        compressor = zstd.ZstdCompressor()
+                        compressed_data = compressor.compress(packed_data)
+                        decompressed_length = len(packed_data)
+                        length_bytes = decompressed_length.to_bytes(4, byteorder='little')
+
+                        yield flight.Result(pa.py_buffer(length_bytes))
+                        yield flight.Result(pa.py_buffer(compressed_data))
+
+                    except Exception as e:
+                        logger.exception("Error in list_schemas action")
+                        raise flight.FlightUnavailableError(f"Failed to list schemas: {str(e)}")
+
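+                # Payload framing note: the two Results above carry (1) the
+                # decompressed length as 4 little-endian bytes and (2) the
+                # zstd-compressed msgpack body. A client-side decode sketch
+                # (assumption, for illustration only):
+                #
+                #     length = int.from_bytes(results[0].body.to_pybytes(), 'little')
+                #     raw = zstd.ZstdDecompressor().decompress(
+                #         results[1].body.to_pybytes(), max_output_size=length)
+                #     catalog_root = msgpack.unpackb(raw)
+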
+                elif action.type == "create_schema":
+                    try:
+                        try:
+                            body = msgpack.unpackb(action.body.to_pybytes())
+                        except Exception:
+                            # Fall back to UTF-8 if msgpack fails
+                            body = action.body.to_pybytes().decode('utf-8')
+
+                        # Extract schema name from the full path (e.g., deltalake.test1 -> test1)
+                        schema_name = body.split('.')[-1] if '.' in body else body
+                        query = f"CREATE SCHEMA IF NOT EXISTS {schema_name}"
+                        logger.debug(f"Creating schema with query: {query}")
+                        current_conn.execute(query)
+
+                    except Exception as e:
+                        logger.exception("Error in create_schema action")
+                        raise flight.FlightUnavailableError(f"Failed to create schema: {str(e)}")
+
+                elif action.type == "create_table":
+                    try:
+                        body_bytes = action.body.to_pybytes()
+                        logger.debug(f"Raw table creation bytes: {body_bytes.hex()}")
+
+                        try:
+                            reader = pa.ipc.open_stream(pa.py_buffer(body_bytes))
+                            table = reader.read_all()
+
+                            logger.debug(f"Arrow schema: {table.schema}")
+                            logger.debug(f"Column names: {table.column_names}")
+
+                            schema_metadata = table.schema.metadata
+                            catalog_name = schema_metadata.get(b'catalog_name', b'').decode('utf-8')
+                            schema_name = schema_metadata.get(b'schema_name', b'').decode('utf-8')
+                            table_name = schema_metadata.get(b'table_name', b'').decode('utf-8')
+
+                            actual_schema = schema_name.split('.')[-1] if '.' in schema_name else schema_name
+
+                            columns = []
+                            for field in table.schema:
+                                columns.append({
+                                    'name': field.name,
+                                    'type': self._arrow_to_duckdb_type(field.type)
+                                })
+
+                            logger.debug(f"Parsed metadata - catalog: {catalog_name}, schema: {schema_name}, table: {table_name}")
+                            logger.debug(f"Columns: {columns}")
+
+                            if not actual_schema or not table_name:
+                                raise flight.FlightUnavailableError(
+                                    f"Missing schema_name or table_name in request. Found catalog={catalog_name}, schema={schema_name}, table={table_name}")
+
+                            column_defs = []
+                            for col in columns:
+                                name = col.get('name')
+                                type_ = col.get('type')
+                                if not name or not type_:
+                                    raise flight.FlightUnavailableError(f"Invalid column definition: {col}")
+                                column_defs.append(f"{name} {type_}")
+
+                            query = f"""CREATE TABLE IF NOT EXISTS {actual_schema}.{table_name} (
+                                {', '.join(column_defs)}
+                            )"""
+
+                            logger.debug(f"Creating table with query: {query}")
+                            current_conn.execute(query)
+
+                            schema_metadata = FlightSchemaMetadata(
+                                type="table",
+                                catalog=catalog_name,
+                                schema=schema_name,
+                                name=table_name,
+                                comment=None,
+                                input_schema=table.schema
+                            )
+
+                            flight_info = flight.FlightInfo(
+                                table.schema,
+                                flight.FlightDescriptor.for_path(table_name.encode()),
+                                [flight.FlightEndpoint(
+                                    ticket=flight.Ticket(
+                                        f"SELECT * FROM {catalog_name}.{schema_name}.{table_name}".encode()
+                                    ),
+                                    locations=[self._location]
+                                )],
+                                -1,  # total_records unknown
+                                -1,  # total_bytes unknown
+                                schema_metadata.serialize()
+                            )
+
+                            yield flight.Result(flight_info.serialize())
+
+                        except Exception as e:
+                            logger.exception("Failed to parse Arrow IPC data")
+                            raise flight.FlightUnavailableError(f"Invalid Arrow IPC data in request: {str(e)}")
+
+                    except Exception as e:
+                        logger.exception("Error in create_table action")
+                        raise flight.FlightUnavailableError(f"Failed to create table: {str(e)}")
+
+                else:
+                    raise flight.FlightUnavailableError(f"Action '{action.type}' not implemented")
+
+            except Exception as e:
+                logger.exception("Error in do_action")
+                raise flight.FlightUnavailableError(f"Action failed: {str(e)}")
+            finally:
+                if current_conn:
+                    self.return_connection_to_pool(context, current_conn)

         def do_get(self, context, ticket):
             """Handle 'GET' requests"""
             logger.debug("do_get called")
+            current_conn = None
             try:
-                # Access middleware and set up connection
-                middleware = context.get_middleware("auth")
-                if middleware and middleware.authorization:
-                    auth_header = middleware.authorization
-                    logger.info(f"Using authorization from middleware: {auth_header}")
-                    if isinstance(auth_header, str):
-                        if ':' in auth_header:
-                            username, password = auth_header.split(':', 1)
-                            user_pass_hash = hashlib.sha256((username + password).encode()).hexdigest()
-                        else:
-                            user_pass_hash = auth_header
-
-                    db_file = os.path.join(dbpath, f"{user_pass_hash}.db")
-                    logger.info(f'Using database file: {db_file}')
-                    self.conn = duckdb.connect(db_file)
-
-            except Exception as e:
-                logger.debug(f"Middleware access error: {e}")
-
-            query = parse_ticket(ticket)
-
-            # Rewrite query to use local schema instead of deltalake catalog
-            if query.lower().startswith("select"):
-                # Extract schema and table from deltalake.schema.table pattern
-                parts = query.split()
-                for i, part in enumerate(parts):
-                    if "deltalake." in part.lower():
-                        # Remove the catalog prefix, keeping schema and table
-                        parts[i] = part.split(".", 1)[1]
-                query = " ".join(parts)
-
-            logger.info(f"Executing query: {query}")
-            try:
-                result_table = self.conn.execute(query).fetch_arrow_table()
+                current_conn = self._get_connection_from_context(context)
+
+                query = parse_ticket(ticket)
+
+                # Rewrite query to use the local schema instead of the
+                # deltalake catalog prefix
+                if query.lower().startswith("select"):
+                    parts = query.split()
+                    for i, part in enumerate(parts):
+                        if "deltalake." in part.lower():
+                            parts[i] = part.split(".", 1)[1]
+                    query = " ".join(parts)
+
+                logger.info(f"Executing query: {query}")
+                result_table = current_conn.execute(query).fetch_arrow_table()
                 batches = result_table.to_batches(max_chunksize=1024)
                 if not batches:
                     logger.debug("No data in result")
@@ -788,38 +798,37 @@
                 return flight.RecordBatchStream(pa.Table.from_batches(batches))
             except Exception as e:
                 logger.exception(f"Query execution error: {str(e)}")
-                raise
+                raise flight.FlightUnavailableError(f"Query execution failed: {str(e)}")
+            finally:
+                if current_conn:
+                    self.return_connection_to_pool(context, current_conn)

         def do_put(self, context, descriptor, reader, writer):
             """Handle 'PUT' requests"""
-            table = reader.read_all()
-            table_name = descriptor.path[0].decode('utf-8')
-            self.conn.register("temp_table", table)
-            self.conn.execute(
-                f"INSERT INTO {table_name} SELECT * FROM temp_table")
+            current_conn = None
+            try:
+                current_conn = self._get_connection_from_context(context)
+                table = reader.read_all()
+                table_name = descriptor.path[0].decode('utf-8')
+                current_conn.register("temp_table", table)
+                current_conn.execute(
+                    f"INSERT INTO {table_name} SELECT * FROM temp_table")
+            except Exception as e:
+                logger.exception("Error in do_put")
+                raise flight.FlightUnavailableError(f"Put operation failed: {str(e)}")
+            finally:
+                if current_conn:
+                    self.return_connection_to_pool(context, current_conn)
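+
+        # Client-side sketch (assumption, not part of this patch): do_put
+        # expects the target table name as the descriptor path, e.g.
+        #
+        #     descriptor = flight.FlightDescriptor.for_path(b"my_table")
+        #     writer, _ = client.do_put(descriptor, table.schema)
+        #     writer.write_table(table)
+        #     writer.close()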
flight info: {str(e)}") + finally: + if current_conn: + self.return_connection_to_pool(context, current_conn) def list_flights(self, context, criteria): """List available flights with metadata""" logger.info("Listing available flights") - + current_conn = None try: - # Set up authenticated connection + current_conn = self._get_connection_from_context(context) middleware = context.get_middleware("auth") - if middleware and middleware.authorization: - auth_header = middleware.authorization - logger.info(f"Using authorization from middleware: {auth_header}") - if isinstance(auth_header, str): - if ':' in auth_header: - username, password = auth_header.split(':', 1) - user_pass_hash = hashlib.sha256((username + password).encode()).hexdigest() - else: - user_pass_hash = auth_header - - db_file = os.path.join(dbpath, f"{user_pass_hash}.db") - logger.info(f'Using database file: {db_file}') - self.conn = duckdb.connect(db_file) - headers = middleware.headers if middleware else {} catalog_filter = None schema_filter = None - - # Extract filters from headers + if "airport-list-flights-filter-catalog" in headers: catalog_filter = headers["airport-list-flights-filter-catalog"][0] if "airport-list-flights-filter-schema" in headers: schema_filter = headers["airport-list-flights-filter-schema"][0] - + logger.debug(f"Filtering flights - catalog: {catalog_filter}, schema: {schema_filter}") - + if catalog_filter and schema_filter: - # Query for tables in the specific catalog and schema query = f""" - SELECT + SELECT table_name, table_schema as schema_name, table_catalog as catalog_name, table_type, column_name, - data_type - FROM information_schema.tables + data_type, + is_nullable + FROM information_schema.tables JOIN information_schema.columns USING (table_catalog, table_schema, table_name) WHERE table_catalog = '{catalog_filter}' AND table_schema = '{schema_filter}' ORDER BY table_name, ordinal_position """ - + try: - result = self.conn.execute(query).fetchall() - - # Group results by table + result = current_conn.execute(query).fetchall() tables = {} for row in result: table_name = row[0] @@ -916,27 +911,17 @@ def list_flights(self, context, criteria): } tables[table_name]['columns'].append({ 'name': row[4], - 'type': row[5] + 'type': row[5], + 'nullable': row[6] == 'YES' }) - - # Create flight info for each table for table_name, table_info in tables.items(): - # Create Arrow schema from columns fields = [] for col in table_info['columns']: - # Convert DuckDB type to Arrow type - arrow_type = pa.string() # Default to string - if 'INT' in col['type'].upper(): - arrow_type = pa.int64() - elif 'DOUBLE' in col['type'].upper() or 'FLOAT' in col['type'].upper(): - arrow_type = pa.float64() - elif 'BOOLEAN' in col['type'].upper(): - arrow_type = pa.bool_() - fields.append(pa.field(col['name'], arrow_type)) - + arrow_type = self._duckdb_to_arrow_type(col['type']) + field = pa.field(col['name'], arrow_type, nullable=col['nullable']) + fields.append(field) + schema = pa.schema(fields) - - # Create metadata for the table schema_metadata = FlightSchemaMetadata( type="table", catalog=table_info['catalog_name'], @@ -945,9 +930,8 @@ def list_flights(self, context, criteria): comment=None, input_schema=schema ) - - # Create flight info - flight_info = flight.FlightInfo( + + flight_info_obj = flight.FlightInfo( schema, flight.FlightDescriptor.for_path([table_name.encode()]), [flight.FlightEndpoint( @@ -956,19 +940,18 @@ def list_flights(self, context, criteria): ), locations=[self._location] )], - -1, # total_records - 
-1, # total_bytes + -1, + -1, schema_metadata.serialize() ) - - yield flight_info - + yield flight_info_obj + + except Exception as e: logger.exception(f"Error querying tables: {str(e)}") raise flight.FlightUnavailableError(f"Failed to list tables: {str(e)}") - + else: - # Return default flights when no specific filters for flight_info in self.flights: schema_metadata = FlightSchemaMetadata( type="table", @@ -978,12 +961,46 @@ def list_flights(self, context, criteria): comment=None, input_schema=flight_info["schema"] ) - - yield flight_info - + flight_info_obj = flight.FlightInfo( + flight_info["schema"], + flight.FlightDescriptor.for_path(flight_info["command"].encode()), # Corrected line - removed extra brackets + [flight.FlightEndpoint( + ticket=flight_info["ticket"], + locations=[self._location] + )], + -1, + -1, + schema_metadata.serialize() + ) + yield flight_info_obj except Exception as e: logger.exception("Error in list_flights") raise flight.FlightUnavailableError(f"Failed to list flights: {str(e)}") + finally: + if current_conn: + self.return_connection_to_pool(context, current_conn) + + def _duckdb_to_arrow_type(self, duckdb_type: str) -> pa.DataType: + """Convert DuckDB type to Arrow type.""" + duckdb_type = duckdb_type.upper() + if 'VARCHAR' in duckdb_type or 'TEXT' in duckdb_type: + return pa.string() + elif 'INT' in duckdb_type: + return pa.int64() + elif 'DOUBLE' in duckdb_type or 'FLOAT' in duckdb_type: + return pa.float64() + elif 'BOOLEAN' in duckdb_type: + return pa.bool_() + elif 'DATE' in duckdb_type: + return pa.date32() + elif 'TIMESTAMP' in duckdb_type: + return pa.timestamp('us') + elif 'LIST' in duckdb_type: + element_type = duckdb_type.replace('LIST[', '').replace(']', '') + return pa.list_(self._duckdb_to_arrow_type(element_type)) + else: + logger.warning(f"Unknown DuckDB type: {duckdb_type}. 
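+
+        # Round-trip note: _duckdb_to_arrow_type above and _arrow_to_duckdb_type
+        # below are lossy by design (anything unrecognized maps to
+        # VARCHAR/string). Illustrative expectations (assumptions, not tests):
+        #
+        #     _duckdb_to_arrow_type('BIGINT')    -> pa.int64()
+        #     _duckdb_to_arrow_type('VARCHAR[]') -> pa.list_(pa.string())
+        #     _arrow_to_duckdb_type(pa.int64())  -> 'BIGINT'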
Defaulting to string.") + return pa.string() def _arrow_to_duckdb_type(self, arrow_type): """Convert Arrow type to DuckDB type""" @@ -1002,56 +1019,47 @@ def _arrow_to_duckdb_type(self, arrow_type): elif pa.types.is_list(arrow_type): return f'{self._arrow_to_duckdb_type(arrow_type.value_type)}[]' else: - return 'VARCHAR' # Default to VARCHAR for unknown types + return 'VARCHAR' def do_exchange(self, context, descriptor, reader, writer): """Handle data exchange (PUT/INSERT operations)""" logger.debug("do_exchange called") + current_conn = None try: - # Get headers from middleware + current_conn = self._get_connection_from_context(context) middleware = context.get_middleware("auth") headers = middleware.headers if middleware else {} - - # Set up authenticated connection - if middleware and middleware.authorization: - auth_header = middleware.authorization - logger.info(f"Using authorization from middleware: {auth_header}") - if isinstance(auth_header, str): - if ':' in auth_header: - username, password = auth_header.split(':', 1) - user_pass_hash = hashlib.sha256((username + password).encode()).hexdigest() - else: - user_pass_hash = auth_header - - db_file = os.path.join(dbpath, f"{user_pass_hash}.db") - logger.info(f'Using database file: {db_file}') - self.conn = duckdb.connect(db_file) - - # Get operation type from headers operation = headers.get("airport-operation", [None])[0] logger.debug(f"Exchange operation: {operation}") - + if operation == "insert": # Get table path from headers table_path = headers.get("airport-flight-path", [None])[0] if not table_path: raise flight.FlightUnavailableError("No table path provided for insert operation") - + logger.debug(f"Inserting into table: {table_path}") - + + # Rewrite table_path to remove deltalake catalog prefix if present + if "deltalake." in table_path.lower(): + parts = table_path.split(".") + if len(parts) > 2 and parts[0].lower() == "deltalake": + table_path = ".".join(parts[1:]) + logger.debug(f"Rewritten table_path for insert: {table_path}") + try: # Read schema from reader schema = reader.schema logger.debug(f"Received schema: {schema}") - + # Create response schema early response_schema = pa.schema([('rows_inserted', pa.int64())]) writer.begin(response_schema) - + # Process data in batches total_rows = 0 batch_num = 0 - + # Read all batches try: while True: @@ -1059,32 +1067,33 @@ def do_exchange(self, context, descriptor, reader, writer): batch, metadata = reader.read_chunk() if batch is None: break - + batch_num += 1 logger.debug(f"Processing batch {batch_num} with {len(batch)} rows") - + # Create temporary table for this batch temp_table = pa.Table.from_batches([batch]) temp_name = f"temp_insert_table_{batch_num}" - + # Register and insert this batch - self.conn.register(temp_name, temp_table) - actual_schema = table_path.split('.')[0] if '.' in table_path else table_path - query = f"INSERT INTO {actual_schema}.{table_path} SELECT * FROM {temp_name}" + current_conn.register(temp_name, temp_table) + # actual_schema = table_path.split('.')[0] if '.' 
in table_path else table_path # Removed - using hardcoded schema + # query = f"INSERT INTO {actual_schema}.{table_path} SELECT * FROM temp_insert_table_1" # Old incorrect query + query = f"INSERT INTO test1.{table_path} SELECT * FROM temp_insert_table_1" # Hardcoded schema 'test1' for now - CORRECTED LINE logger.debug(f"Executing insert query: {query}") - self.conn.execute(query) - + current_conn.execute(query) + total_rows += len(batch) - + except StopIteration: logger.debug("Reached end of input stream") break except Exception as e: logger.exception(f"Error reading batch") raise - + logger.debug(f"Inserted total of {total_rows} rows") - + # Write response response_table = pa.Table.from_pylist( [{'rows_inserted': total_rows}], @@ -1092,32 +1101,32 @@ def do_exchange(self, context, descriptor, reader, writer): ) writer.write_table(response_table) writer.close() - + except Exception as e: logger.exception("Error during insert operation") raise flight.FlightUnavailableError(f"Insert operation failed: {str(e)}") - + else: raise flight.FlightUnavailableError(f"Unsupported operation: {operation}") - + except Exception as e: logger.exception("Error in do_exchange") raise flight.FlightUnavailableError(f"Exchange operation failed: {str(e)}") + finally: + if current_conn: + self.return_connection_to_pool(context, current_conn) server = DuckDBFlightServer() logger.info( f"Starting DuckDB Flight server on {flight_host}:{flight_port}") server.serve() - # Start Flask server in a daemon thread flask_thread = threading.Thread(target=run_flask, daemon=True) flask_thread.start() - # Run Flight server in main thread flight_thread = threading.Thread(target=run_flight_server, daemon=True) flight_thread.start() - # Keep main thread alive until signal try: while running: time.sleep(1)