Skip to main content

malbox_plugin_sdk/runtime/
host.rs

1//! Host plugin runtime - WaitSet-based event loop over iceoryx2 IPC.
2//!
3//! The runtime multiplexes several IPC channels into a single event loop:
4//!
5//! - **Task dispatch** - request/response channel from the daemon.
6//! - **Daemon events** - pub/sub for lifecycle and system signals.
7//! - **Result chaining** - subscribes to other plugins' result channels
8//!   so a plugin can react to upstream outputs.
//! - **WaitSet wakeup** - notification-driven wakeup, backed by a 500 ms
//!   interval tick so shutdown and all channels are still re-checked periodically.
10//!
11//! Gated behind the `host` feature flag.
12
13use crate::context::Context;
14use crate::error::{Result, SdkError};
15use crate::meta::PluginMeta;
16use crate::plugin::HostPlugin;
17
18use malbox_plugin_transport::ipc::headers::{
19    EventKind, FLAG_IS_FINAL, ResponseKind, ResultHeader, TaskResponseHeader,
20};
21use malbox_plugin_transport::ipc::{
22    ActiveTaskRequest, DaemonEventSubscriber, EventHeader, PluginEventPublisher,
23    PluginWakeupListener, ResultPublisher, ResultSubscriber, TaskServer,
24};
25use malbox_plugin_transport::ipc::{
26    CallbackProgression, IpcService, Node, NodeBuilder, WaitSetBuilder,
27};
28use malbox_plugin_transport::messages::events::Event;
29use malbox_plugin_transport::traits::TransportEmitter;
30
31use std::collections::HashMap;
32use std::path::PathBuf;
33use std::sync::Arc;
34use std::sync::atomic::{AtomicBool, Ordering};
35
36use serde::{Deserialize, Serialize};
37use tracing::{debug, error, info, instrument};
38
39/// Task request payload serialized with postcard over IPC.
40#[derive(Serialize, Deserialize)]
41pub struct TaskRequestPayload<'a> {
42    /// Path to the sample file on the host filesystem.
43    pub sample_path: &'a str,
44    /// Key-value configuration pairs for this task.
45    pub config: Vec<(&'a str, &'a str)>,
46}
47
48/// Owned version for deserialization.
49#[derive(Deserialize)]
50struct TaskRequestPayloadOwned {
51    pub sample_path: String,
52    pub config: Vec<(String, String)>,
53}
54
55/// Host plugin runtime with WaitSet-driven event loop.
56#[allow(dead_code)]
57pub struct HostRuntime<P> {
58    plugin: P,
59    meta: PluginMeta,
60    plugin_id: String,
61
62    // iceoryx2 node - must outlive all ports created from it
63    _node: Node<IpcService>,
64
65    // Task dispatch (request/response)
66    task_server: TaskServer,
67
68    // Events
69    daemon_event_sub: DaemonEventSubscriber,
70    plugin_event_pub: PluginEventPublisher,
71
72    // Results (chaining)
73    result_pub: ResultPublisher,
74    result_subscriptions: Vec<(String, ResultSubscriber)>,
75
76    // WaitSet wakeup
77    wakeup_listener: PluginWakeupListener,
78
79    // Lifecycle
80    shutdown: AtomicBool,
81}
82
83impl<P: HostPlugin> HostRuntime<P> {
84    /// Create a new host plugin runtime.
85    ///
86    /// `subscribed_plugins` is the list of plugin IDs this plugin subscribes to
87    /// for event/result chaining (declared via macro).
88    pub fn new(plugin: P, meta: PluginMeta, subscribed_plugins: &[&str]) -> Result<Self> {
89        info!(plugin = %meta.name(), "Initializing host runtime");
90
91        let node = NodeBuilder::new()
92            .create::<IpcService>()
93            .map_err(|e| SdkError::Init(format!("Failed to create IPC node: {e}")))?;
94
95        let plugin_id = std::env::var("MALBOX_PLUGIN_ID")
96            .map_err(|_| SdkError::Init("MALBOX_PLUGIN_ID environment variable not set".into()))?;
97
98        let task_server = TaskServer::new(&node, &plugin_id)
99            .map_err(|e| SdkError::Init(format!("Failed to create task server: {e}")))?;
100
101        let daemon_event_sub = DaemonEventSubscriber::new(&node).map_err(|e| {
102            SdkError::Init(format!("Failed to create daemon event subscriber: {e}"))
103        })?;
104
105        let plugin_event_pub = PluginEventPublisher::new(&node, &plugin_id)
106            .map_err(|e| SdkError::Init(format!("Failed to create plugin event publisher: {e}")))?;
107
108        let result_pub = ResultPublisher::new(&node, &plugin_id)
109            .map_err(|e| SdkError::Init(format!("Failed to create result publisher: {e}")))?;
110
111        let wakeup_listener = PluginWakeupListener::new(&node, &plugin_id)
112            .map_err(|e| SdkError::Init(format!("Failed to create wakeup listener: {e}")))?;
113
114        // Subscribe to other plugins' result channels for chaining
115        let mut result_subscriptions = Vec::new();
116        for &sub_plugin_id in subscribed_plugins {
117            let sub = ResultSubscriber::new(&node, sub_plugin_id).map_err(|e| {
118                SdkError::Init(format!(
119                    "Failed to subscribe to plugin '{sub_plugin_id}' results: {e}"
120                ))
121            })?;
122            result_subscriptions.push((sub_plugin_id.to_string(), sub));
123        }
124
125        info!(plugin = %meta.name(), "Host runtime initialized");
126
127        Ok(Self {
128            plugin,
129            meta,
130            plugin_id,
131            _node: node,
132            task_server,
133            daemon_event_sub,
134            plugin_event_pub,
135            result_pub,
136            result_subscriptions,
137            wakeup_listener,
138            shutdown: AtomicBool::new(false),
139        })
140    }
141
142    /// Run the plugin event loop. Blocks until shutdown.
143    #[instrument(skip_all, fields(plugin = %self.meta.name()), err)]
144    pub fn run(&self) -> Result<()> {
145        self.plugin.on_start(HashMap::new())?;
146
147        // Emit PluginStarted
148        let header = EventHeader {
149            event_type: EventKind::PluginStarted as u16,
150            source_plugin_id: 0,
151            associated_id: 0,
152            _reserved: [0; 6],
153        };
154        let _ = self.plugin_event_pub.emit_signal(&header);
155        info!(plugin = %self.meta.name(), "Plugin started, entering WaitSet event loop");
156
157        // Build WaitSet
158        let waitset = WaitSetBuilder::new()
159            .create::<IpcService>()
160            .map_err(|e| SdkError::Init(format!("Failed to create WaitSet: {e}")))?;
161
162        let guard = waitset
163            .attach_notification(self.wakeup_listener.listener())
164            .map_err(|e| SdkError::Init(format!("Failed to attach to WaitSet: {e}")))?;
165
166        // Also attach an interval for periodic checks (shutdown, etc.)
167        let _interval_guard = waitset
168            .attach_interval(core::time::Duration::from_millis(500))
169            .map_err(|e| SdkError::Init(format!("Failed to attach interval: {e}")))?;
170
171        let result = waitset.wait_and_process(|attachment_id| {
172            if self.shutdown.load(Ordering::Relaxed) {
173                return CallbackProgression::Stop;
174            }
175
176            if attachment_id.has_event_from(&guard) {
177                // Drain notification kinds (we don't strictly need to distinguish,
178                // since we check all channels anyway)
179                let _ = self.wakeup_listener.drain();
180            }
181
182            // Always check all channels on any wakeup
183            self.drain_tasks();
184            self.drain_daemon_events();
185            self.drain_result_subscriptions();
186
187            if self.shutdown.load(Ordering::Relaxed) {
188                return CallbackProgression::Stop;
189            }
190
191            CallbackProgression::Continue
192        });
193
194        if let Err(e) = result {
195            error!(plugin = %self.meta.name(), error = ?e, "WaitSet error");
196        }
197
198        // Shutdown
199        if let Err(e) = self.plugin.on_stop() {
200            error!(plugin = %self.meta.name(), error = %e, "on_stop error");
201        }
202
203        let header = EventHeader {
204            event_type: EventKind::PluginStopped as u16,
205            source_plugin_id: 0,
206            associated_id: 0,
207            _reserved: [0; 6],
208        };
209        let _ = self.plugin_event_pub.emit_signal(&header);
210
211        info!(plugin = %self.meta.name(), "Plugin runtime exited");
212        Ok(())
213    }
214
215    fn drain_tasks(&self) {
216        while let Ok(Some(active_request)) = self.task_server.receive() {
217            self.handle_task(active_request);
218        }
219    }
220
221    fn drain_daemon_events(&self) {
222        while let Ok(Some(event)) = self.daemon_event_sub.try_recv() {
223            let kind = EventKind::try_from(event.header.event_type);
224            debug!(?kind, "Daemon event received");
225
226            if matches!(kind, Ok(EventKind::DaemonShutdown)) {
227                self.shutdown.store(true, Ordering::Relaxed);
228                return;
229            }
230
231            if let Ok(event) = Event::try_from(&event.header)
232                && let Err(e) = self.plugin.on_event(event)
233            {
234                error!(plugin = %self.meta.name(), error = %e, "on_event error");
235            }
236        }
237    }
238
239    fn drain_result_subscriptions(&self) {
240        for (source_id, subscriber) in &self.result_subscriptions {
241            while let Ok(Some(result)) = subscriber.try_recv() {
242                let result_name = extract_result_name(&result.header, &result.payload);
243
244                // Deliver as a lightweight PluginResultAvailable event.
245                // The actual data is accessible through the result subscriber
246                // at the macro-generated handler level (lazy access).
247                let event = Event::PluginResultAvailable {
248                    source: source_id.clone(),
249                    result_name,
250                };
251
252                if let Err(e) = self.plugin.on_event(event) {
253                    error!(
254                        plugin = %self.meta.name(),
255                        source = %source_id,
256                        error = %e,
257                        "on_event (result available) error"
258                    );
259                }
260            }
261        }
262    }
263
264    fn handle_task(&self, active_request: ActiveTaskRequest) {
265        let header = active_request.header();
266        let task_id = header.task_id;
267
268        let payload_bytes = active_request.payload();
269        let request: TaskRequestPayloadOwned = match postcard::from_bytes(payload_bytes) {
270            Ok(r) => r,
271            Err(e) => {
272                error!(task_id, error = %e, "Failed to deserialize task request");
273                let resp = TaskResponseHeader {
274                    task_id,
275                    kind: ResponseKind::Error as u8,
276                    flags: FLAG_IS_FINAL,
277                    ..Default::default()
278                };
279                let msg = format!("deserialization error: {e}");
280                let _ = active_request.send_response(&resp, msg.as_bytes());
281                return;
282            }
283        };
284
285        let config: HashMap<String, String> = request.config.into_iter().collect();
286
287        let (result_tx, mut result_rx) = tokio::sync::mpsc::channel(256);
288
289        let emitter: Arc<dyn TransportEmitter + Send + Sync> = Arc::new(());
290        let task_ctx = Context::new(
291            task_id,
292            PathBuf::from(&request.sample_path),
293            config,
294            emitter,
295            Some(result_tx),
296            #[cfg(feature = "guest")]
297            None,
298        );
299
300        let result = self.plugin.on_task(&task_ctx);
301        task_ctx.close_result_channel();
302
303        while let Ok(msg) = result_rx.try_recv() {
304            forward_result_to_ipc(&active_request, &msg);
305        }
306
307        match result {
308            Ok(()) => {
309                let resp = TaskResponseHeader {
310                    task_id,
311                    flags: FLAG_IS_FINAL,
312                    ..Default::default()
313                };
314                let _ = active_request.send_response(&resp, &[]);
315
316                let ev = EventHeader {
317                    event_type: EventKind::TaskCompleted as u16,
318                    source_plugin_id: 0,
319                    associated_id: task_id,
320                    _reserved: [0; 6],
321                };
322                let _ = self.plugin_event_pub.emit_signal(&ev);
323                info!(plugin = %self.meta.name(), task_id, "Task completed");
324            }
325            Err(e) => {
326                let resp = TaskResponseHeader {
327                    task_id,
328                    kind: ResponseKind::Error as u8,
329                    flags: FLAG_IS_FINAL,
330                    ..Default::default()
331                };
332                let msg = e.to_string();
333                let _ = active_request.send_response(&resp, msg.as_bytes());
334
335                let ev = EventHeader {
336                    event_type: EventKind::TaskFailed as u16,
337                    source_plugin_id: 0,
338                    associated_id: task_id,
339                    _reserved: [0; 6],
340                };
341                let _ = self.plugin_event_pub.emit_signal(&ev);
342                error!(plugin = %self.meta.name(), task_id, error = %e, "Task failed");
343            }
344        }
345    }
346}
347
348use crate::context::message::{ResultKind as MsgResultKind, TaskResultMessage};
349use malbox_plugin_transport::ipc::headers::ResultFormat as IpcResultFormat;
350
351fn forward_result_to_ipc(active_request: &ActiveTaskRequest, msg: &TaskResultMessage) {
352    let (kind, ipc_format) = match msg.kind {
353        MsgResultKind::Progress => (ResponseKind::Progress, IpcResultFormat::Json),
354        MsgResultKind::Result => (
355            ResponseKind::Result,
356            match msg.format {
357                crate::context::message::ResultFormat::Json => IpcResultFormat::Json,
358                _ => IpcResultFormat::Bytes,
359            },
360        ),
361        MsgResultKind::ResultRef => return,
362    };
363
364    let name_bytes = msg.result_name.as_bytes();
365    let name_len = name_bytes.len().min(255) as u8;
366
367    let mut payload = Vec::with_capacity(name_len as usize + msg.data.len());
368    payload.extend_from_slice(&name_bytes[..name_len as usize]);
369    payload.extend_from_slice(&msg.data);
370
371    let header = TaskResponseHeader {
372        task_id: msg.task_id,
373        kind: kind as u8,
374        format: ipc_format as u8,
375        name_len,
376        ..Default::default()
377    };
378
379    if let Err(e) = active_request.send_response(&header, &payload) {
380        debug!(task_id = msg.task_id, error = %e, "failed to forward result via IPC");
381    }
382}
383
384fn extract_result_name(header: &ResultHeader, payload: &[u8]) -> String {
385    let name_len = header.name_len as usize;
386    if name_len > 0 && payload.len() >= name_len {
387        String::from_utf8_lossy(&payload[..name_len]).to_string()
388    } else {
389        String::new()
390    }
391}