diff --git a/CHANGELOG.md b/CHANGELOG.md index c0f8afac..3e1f6297 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - Create data manager and spawn new process to keep the vts dictionary. [#191](https://github.com/greenbone/ospd/pull/191) - Update daemon start sequence. Run daemon.check before daemon.init now. [#197](https://github.com/greenbone/ospd/pull/197) - Improve get_vts cmd response, sending the vts piece by piece.[#201](https://github.com/greenbone/ospd/pull/201) +- Start the server before initialize to respond to the client.[#209](https://github.com/greenbone/ospd/pull/209) ### Fixed - Fix stop scan. Wait for the scan process to be stopped before delete it from the process table. [#204](https://github.com/greenbone/ospd/pull/204) diff --git a/ospd/main.py b/ospd/main.py index a52640db..7f8a4374 100644 --- a/ospd/main.py +++ b/ospd/main.py @@ -28,17 +28,20 @@ from functools import partial from typing import Type, Optional +from pathlib import Path -from ospd.misc import go_to_background, create_pid, remove_pidfile +from ospd.misc import go_to_background, create_pid from ospd.ospd import OSPDaemon from ospd.parser import create_parser, ParserType -from ospd.server import TlsServer, UnixSocketServer +from ospd.server import TlsServer, UnixSocketServer, BaseServer COPYRIGHT = """Copyright (C) 2014, 2015, 2018, 2019 Greenbone Networks GmbH License GPLv2+: GNU GPL version 2 or later This is free software: you are free to change and redistribute it. There is NO WARRANTY, to the extent permitted by law.""" +LOGGER = logging.getLogger(__name__) + def print_version(daemon: OSPDaemon, file=sys.stdout): """ Prints the server version and license information.""" @@ -104,6 +107,24 @@ def init_logging( os.dup2(syslog_fd, 2) +def exit_cleanup( + pidfile: str, server: BaseServer, _signum=None, _frame=None +) -> None: + """ Removes the pidfile before ending the daemon. """ + pidpath = Path(pidfile) + + if not pidpath.is_file(): + return + + with pidpath.open() as f: + if int(f.read()) == os.getpid(): + LOGGER.info("Shutting-down server ...") + server.close() + LOGGER.debug("Finishing daemon process") + pidpath.unlink() + sys.exit() + + def main( name: str, daemon_class: Type[OSPDaemon], @@ -153,14 +174,13 @@ def main( sys.exit() # Set signal handler and cleanup - atexit.register(remove_pidfile, pidfile=args.pid_file) - signal.signal(signal.SIGTERM, partial(remove_pidfile, args.pid_file)) + atexit.register(exit_cleanup, pidfile=args.pid_file, server=server) + signal.signal(signal.SIGTERM, partial(exit_cleanup, args.pid_file, server)) if not daemon.check(): return 1 - daemon.init() - - daemon.run(server) + daemon.init(server) + daemon.run() return 0 diff --git a/ospd/misc.py b/ospd/misc.py index 84229a49..24763896 100644 --- a/ospd/misc.py +++ b/ospd/misc.py @@ -118,17 +118,3 @@ def create_pid(pidfile: str) -> bool: return False return True - - -def remove_pidfile(pidfile: str, _signum=None, _frame=None) -> None: - """ Removes the pidfile before ending the daemon. """ - pidpath = Path(pidfile) - - if not pidpath.is_file(): - return - - with pidpath.open() as f: - if int(f.read()) == os.getpid(): - LOGGER.debug("Finishing daemon process") - pidpath.unlink() - sys.exit() diff --git a/ospd/ospd.py b/ospd/ospd.py index b1db88c8..d27c0a44 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -119,6 +119,8 @@ def __init__( self.server_version = None # Set by the subclass. + self.initialized = None # Set after initialization finished + self.scaninfo_store_time = kwargs.get('scaninfo_store_time') self.protocol_version = PROTOCOL_VERSION @@ -142,11 +144,13 @@ def __init__( else: self.vts_filter = VtsFilter() - def init(self) -> None: + def init(self, server: BaseServer) -> None: """ Should be overridden by a subclass if the initialization is costly. Will be called after check. """ + server.start(self.handle_client_stream) + self.initialized = True def set_command_attributes(self, name: str, attributes: Dict) -> None: """ Sets the xml attributes of a specified command. """ @@ -460,6 +464,15 @@ def handle_client_stream(self, stream) -> None: logger.debug("Empty client stream") return + if not self.initialized: + exception = OspdCommandError( + '%s is still starting' % self.daemon_info['name'], 'error' + ) + response = exception.as_xml() + stream.write(response) + stream.close() + return + response = None try: self.handle_command(data, stream) @@ -1177,12 +1190,10 @@ def check(self): """ Asserts to False. Should be implemented by subclass. """ raise NotImplementedError - def run(self, server: BaseServer) -> None: + def run(self) -> None: """ Starts the Daemon, handling commands until interrupted. """ - server.start(self.handle_client_stream) - try: while True: time.sleep(10) @@ -1191,9 +1202,6 @@ def run(self, server: BaseServer) -> None: self.wait_for_children() except KeyboardInterrupt: logger.info("Received Ctrl-C shutting-down ...") - finally: - logger.info("Shutting-down server ...") - server.close() def scheduler(self): """ Should be implemented by subclass in case of need