malbox_plugin_sdk/runtime/
host.rs1use crate::context::Context;
14use crate::error::{Result, SdkError};
15use crate::meta::PluginMeta;
16use crate::plugin::HostPlugin;
17
18use malbox_plugin_transport::ipc::headers::{
19 EventKind, FLAG_IS_FINAL, ResponseKind, ResultHeader, TaskResponseHeader,
20};
21use malbox_plugin_transport::ipc::{
22 ActiveTaskRequest, DaemonEventSubscriber, EventHeader, PluginEventPublisher,
23 PluginWakeupListener, ResultPublisher, ResultSubscriber, TaskServer,
24};
25use malbox_plugin_transport::ipc::{
26 CallbackProgression, IpcService, Node, NodeBuilder, WaitSetBuilder,
27};
28use malbox_plugin_transport::messages::events::Event;
29use malbox_plugin_transport::traits::TransportEmitter;
30
31use std::collections::HashMap;
32use std::path::PathBuf;
33use std::sync::Arc;
34use std::sync::atomic::{AtomicBool, Ordering};
35
36use serde::{Deserialize, Serialize};
37use tracing::{debug, error, info, instrument};
38
39#[derive(Serialize, Deserialize)]
41pub struct TaskRequestPayload<'a> {
42 pub sample_path: &'a str,
44 pub config: Vec<(&'a str, &'a str)>,
46}
47
48#[derive(Deserialize)]
50struct TaskRequestPayloadOwned {
51 pub sample_path: String,
52 pub config: Vec<(String, String)>,
53}
54
55#[allow(dead_code)]
57pub struct HostRuntime<P> {
58 plugin: P,
59 meta: PluginMeta,
60 plugin_id: String,
61
62 _node: Node<IpcService>,
64
65 task_server: TaskServer,
67
68 daemon_event_sub: DaemonEventSubscriber,
70 plugin_event_pub: PluginEventPublisher,
71
72 result_pub: ResultPublisher,
74 result_subscriptions: Vec<(String, ResultSubscriber)>,
75
76 wakeup_listener: PluginWakeupListener,
78
79 shutdown: AtomicBool,
81}
82
83impl<P: HostPlugin> HostRuntime<P> {
84 pub fn new(plugin: P, meta: PluginMeta, subscribed_plugins: &[&str]) -> Result<Self> {
89 info!(plugin = %meta.name(), "Initializing host runtime");
90
91 let node = NodeBuilder::new()
92 .create::<IpcService>()
93 .map_err(|e| SdkError::Init(format!("Failed to create IPC node: {e}")))?;
94
95 let plugin_id = std::env::var("MALBOX_PLUGIN_ID")
96 .map_err(|_| SdkError::Init("MALBOX_PLUGIN_ID environment variable not set".into()))?;
97
98 let task_server = TaskServer::new(&node, &plugin_id)
99 .map_err(|e| SdkError::Init(format!("Failed to create task server: {e}")))?;
100
101 let daemon_event_sub = DaemonEventSubscriber::new(&node).map_err(|e| {
102 SdkError::Init(format!("Failed to create daemon event subscriber: {e}"))
103 })?;
104
105 let plugin_event_pub = PluginEventPublisher::new(&node, &plugin_id)
106 .map_err(|e| SdkError::Init(format!("Failed to create plugin event publisher: {e}")))?;
107
108 let result_pub = ResultPublisher::new(&node, &plugin_id)
109 .map_err(|e| SdkError::Init(format!("Failed to create result publisher: {e}")))?;
110
111 let wakeup_listener = PluginWakeupListener::new(&node, &plugin_id)
112 .map_err(|e| SdkError::Init(format!("Failed to create wakeup listener: {e}")))?;
113
114 let mut result_subscriptions = Vec::new();
116 for &sub_plugin_id in subscribed_plugins {
117 let sub = ResultSubscriber::new(&node, sub_plugin_id).map_err(|e| {
118 SdkError::Init(format!(
119 "Failed to subscribe to plugin '{sub_plugin_id}' results: {e}"
120 ))
121 })?;
122 result_subscriptions.push((sub_plugin_id.to_string(), sub));
123 }
124
125 info!(plugin = %meta.name(), "Host runtime initialized");
126
127 Ok(Self {
128 plugin,
129 meta,
130 plugin_id,
131 _node: node,
132 task_server,
133 daemon_event_sub,
134 plugin_event_pub,
135 result_pub,
136 result_subscriptions,
137 wakeup_listener,
138 shutdown: AtomicBool::new(false),
139 })
140 }
141
142 #[instrument(skip_all, fields(plugin = %self.meta.name()), err)]
144 pub fn run(&self) -> Result<()> {
145 self.plugin.on_start(HashMap::new())?;
146
147 let header = EventHeader {
149 event_type: EventKind::PluginStarted as u16,
150 source_plugin_id: 0,
151 associated_id: 0,
152 _reserved: [0; 6],
153 };
154 let _ = self.plugin_event_pub.emit_signal(&header);
155 info!(plugin = %self.meta.name(), "Plugin started, entering WaitSet event loop");
156
157 let waitset = WaitSetBuilder::new()
159 .create::<IpcService>()
160 .map_err(|e| SdkError::Init(format!("Failed to create WaitSet: {e}")))?;
161
162 let guard = waitset
163 .attach_notification(self.wakeup_listener.listener())
164 .map_err(|e| SdkError::Init(format!("Failed to attach to WaitSet: {e}")))?;
165
166 let _interval_guard = waitset
168 .attach_interval(core::time::Duration::from_millis(500))
169 .map_err(|e| SdkError::Init(format!("Failed to attach interval: {e}")))?;
170
171 let result = waitset.wait_and_process(|attachment_id| {
172 if self.shutdown.load(Ordering::Relaxed) {
173 return CallbackProgression::Stop;
174 }
175
176 if attachment_id.has_event_from(&guard) {
177 let _ = self.wakeup_listener.drain();
180 }
181
182 self.drain_tasks();
184 self.drain_daemon_events();
185 self.drain_result_subscriptions();
186
187 if self.shutdown.load(Ordering::Relaxed) {
188 return CallbackProgression::Stop;
189 }
190
191 CallbackProgression::Continue
192 });
193
194 if let Err(e) = result {
195 error!(plugin = %self.meta.name(), error = ?e, "WaitSet error");
196 }
197
198 if let Err(e) = self.plugin.on_stop() {
200 error!(plugin = %self.meta.name(), error = %e, "on_stop error");
201 }
202
203 let header = EventHeader {
204 event_type: EventKind::PluginStopped as u16,
205 source_plugin_id: 0,
206 associated_id: 0,
207 _reserved: [0; 6],
208 };
209 let _ = self.plugin_event_pub.emit_signal(&header);
210
211 info!(plugin = %self.meta.name(), "Plugin runtime exited");
212 Ok(())
213 }
214
215 fn drain_tasks(&self) {
216 while let Ok(Some(active_request)) = self.task_server.receive() {
217 self.handle_task(active_request);
218 }
219 }
220
221 fn drain_daemon_events(&self) {
222 while let Ok(Some(event)) = self.daemon_event_sub.try_recv() {
223 let kind = EventKind::try_from(event.header.event_type);
224 debug!(?kind, "Daemon event received");
225
226 if matches!(kind, Ok(EventKind::DaemonShutdown)) {
227 self.shutdown.store(true, Ordering::Relaxed);
228 return;
229 }
230
231 if let Ok(event) = Event::try_from(&event.header)
232 && let Err(e) = self.plugin.on_event(event)
233 {
234 error!(plugin = %self.meta.name(), error = %e, "on_event error");
235 }
236 }
237 }
238
239 fn drain_result_subscriptions(&self) {
240 for (source_id, subscriber) in &self.result_subscriptions {
241 while let Ok(Some(result)) = subscriber.try_recv() {
242 let result_name = extract_result_name(&result.header, &result.payload);
243
244 let event = Event::PluginResultAvailable {
248 source: source_id.clone(),
249 result_name,
250 };
251
252 if let Err(e) = self.plugin.on_event(event) {
253 error!(
254 plugin = %self.meta.name(),
255 source = %source_id,
256 error = %e,
257 "on_event (result available) error"
258 );
259 }
260 }
261 }
262 }
263
264 fn handle_task(&self, active_request: ActiveTaskRequest) {
265 let header = active_request.header();
266 let task_id = header.task_id;
267
268 let payload_bytes = active_request.payload();
269 let request: TaskRequestPayloadOwned = match postcard::from_bytes(payload_bytes) {
270 Ok(r) => r,
271 Err(e) => {
272 error!(task_id, error = %e, "Failed to deserialize task request");
273 let resp = TaskResponseHeader {
274 task_id,
275 kind: ResponseKind::Error as u8,
276 flags: FLAG_IS_FINAL,
277 ..Default::default()
278 };
279 let msg = format!("deserialization error: {e}");
280 let _ = active_request.send_response(&resp, msg.as_bytes());
281 return;
282 }
283 };
284
285 let config: HashMap<String, String> = request.config.into_iter().collect();
286
287 let (result_tx, mut result_rx) = tokio::sync::mpsc::channel(256);
288
289 let emitter: Arc<dyn TransportEmitter + Send + Sync> = Arc::new(());
290 let task_ctx = Context::new(
291 task_id,
292 PathBuf::from(&request.sample_path),
293 config,
294 emitter,
295 Some(result_tx),
296 #[cfg(feature = "guest")]
297 None,
298 );
299
300 let result = self.plugin.on_task(&task_ctx);
301 task_ctx.close_result_channel();
302
303 while let Ok(msg) = result_rx.try_recv() {
304 forward_result_to_ipc(&active_request, &msg);
305 }
306
307 match result {
308 Ok(()) => {
309 let resp = TaskResponseHeader {
310 task_id,
311 flags: FLAG_IS_FINAL,
312 ..Default::default()
313 };
314 let _ = active_request.send_response(&resp, &[]);
315
316 let ev = EventHeader {
317 event_type: EventKind::TaskCompleted as u16,
318 source_plugin_id: 0,
319 associated_id: task_id,
320 _reserved: [0; 6],
321 };
322 let _ = self.plugin_event_pub.emit_signal(&ev);
323 info!(plugin = %self.meta.name(), task_id, "Task completed");
324 }
325 Err(e) => {
326 let resp = TaskResponseHeader {
327 task_id,
328 kind: ResponseKind::Error as u8,
329 flags: FLAG_IS_FINAL,
330 ..Default::default()
331 };
332 let msg = e.to_string();
333 let _ = active_request.send_response(&resp, msg.as_bytes());
334
335 let ev = EventHeader {
336 event_type: EventKind::TaskFailed as u16,
337 source_plugin_id: 0,
338 associated_id: task_id,
339 _reserved: [0; 6],
340 };
341 let _ = self.plugin_event_pub.emit_signal(&ev);
342 error!(plugin = %self.meta.name(), task_id, error = %e, "Task failed");
343 }
344 }
345 }
346}
347
348use crate::context::message::{ResultKind as MsgResultKind, TaskResultMessage};
349use malbox_plugin_transport::ipc::headers::ResultFormat as IpcResultFormat;
350
351fn forward_result_to_ipc(active_request: &ActiveTaskRequest, msg: &TaskResultMessage) {
352 let (kind, ipc_format) = match msg.kind {
353 MsgResultKind::Progress => (ResponseKind::Progress, IpcResultFormat::Json),
354 MsgResultKind::Result => (
355 ResponseKind::Result,
356 match msg.format {
357 crate::context::message::ResultFormat::Json => IpcResultFormat::Json,
358 _ => IpcResultFormat::Bytes,
359 },
360 ),
361 MsgResultKind::ResultRef => return,
362 };
363
364 let name_bytes = msg.result_name.as_bytes();
365 let name_len = name_bytes.len().min(255) as u8;
366
367 let mut payload = Vec::with_capacity(name_len as usize + msg.data.len());
368 payload.extend_from_slice(&name_bytes[..name_len as usize]);
369 payload.extend_from_slice(&msg.data);
370
371 let header = TaskResponseHeader {
372 task_id: msg.task_id,
373 kind: kind as u8,
374 format: ipc_format as u8,
375 name_len,
376 ..Default::default()
377 };
378
379 if let Err(e) = active_request.send_response(&header, &payload) {
380 debug!(task_id = msg.task_id, error = %e, "failed to forward result via IPC");
381 }
382}
383
384fn extract_result_name(header: &ResultHeader, payload: &[u8]) -> String {
385 let name_len = header.name_len as usize;
386 if name_len > 0 && payload.len() >= name_len {
387 String::from_utf8_lossy(&payload[..name_len]).to_string()
388 } else {
389 String::new()
390 }
391}