Skip to main content

malbox_plugin_sdk/
context.rs

1//! Plugin execution context.
2//!
3//! [`Context`] is the main handle plugins use during task execution.
4//! It provides access to task metadata via [`TaskInfo`], result pushing
5//! via [`ResultSink`], and progress reporting back to the daemon.
6
7pub mod message;
8mod result_sink;
9
10pub use result_sink::ResultSink;
11
12use crate::error::{Result, SdkError};
13use message::{ResultFormat, ResultKind, TaskResultMessage};
14
15use malbox_plugin_transport::messages::events::Event;
16use malbox_plugin_transport::traits::TransportEmitter;
17use serde::Serialize;
18use std::collections::{HashMap, HashSet};
19use std::path::{Path, PathBuf};
20use std::sync::{Arc, Mutex};
21use tracing::{info, warn};
22
23#[cfg(feature = "guest")]
24use crate::stash::ResultStash;
25
26#[derive(Serialize)]
27struct ProgressPayload<'a> {
28    progress: f64,
29    message: &'a str,
30}
31
32/// Sender half of the result channel between a plugin and its runtime.
33///
34/// The runtime creates one of these per task and hands it to [`Context`].
35/// During `on_task` the context uses it to stream [`TaskResultMessage`]s
36/// back to the daemon. It is `None` when no task is active.
37pub type ResultSender = tokio::sync::mpsc::Sender<TaskResultMessage>;
38
39pub(super) struct ContextInner {
40    pub task_id: i32,
41    pub sample_path: PathBuf,
42    pub config: HashMap<String, String>,
43    pub emitter: Arc<dyn TransportEmitter + Send + Sync>,
44    pub result_tx: Mutex<Option<ResultSender>>,
45    #[cfg(feature = "guest")]
46    pub stash: Option<Arc<ResultStash>>,
47    pub claimed_paths: Mutex<HashSet<PathBuf>>,
48}
49
50fn lock_or_err<T>(mutex: &Mutex<T>) -> Result<std::sync::MutexGuard<'_, T>> {
51    mutex
52        .lock()
53        .map_err(|_| SdkError::Channel("internal mutex poisoned".into()))
54}
55
56/// The main handle plugins interact with during task execution.
57///
58/// Provides access to task metadata, result pushing, and progress reporting.
59/// `Context` is cheaply cloneable (it wraps an `Arc`) so you can share it
60/// with background threads.
61#[derive(Clone)]
62pub struct Context {
63    inner: Arc<ContextInner>,
64}
65
66impl Context {
67    /// Create a new context with task data, a transport emitter, and an
68    /// optional result channel.
69    ///
70    /// Used by the runtime; downstream test code should use
71    /// [`Context::test_new`](crate::testkit) under the `testkit` feature.
72    pub(crate) fn new(
73        task_id: i32,
74        sample_path: PathBuf,
75        config: HashMap<String, String>,
76        emitter: Arc<dyn TransportEmitter + Send + Sync>,
77        result_tx: Option<ResultSender>,
78        #[cfg(feature = "guest")] stash: Option<Arc<ResultStash>>,
79    ) -> Self {
80        Self {
81            inner: Arc::new(ContextInner {
82                task_id,
83                sample_path,
84                config,
85                emitter,
86                result_tx: Mutex::new(result_tx),
87                #[cfg(feature = "guest")]
88                stash,
89                claimed_paths: Mutex::new(HashSet::new()),
90            }),
91        }
92    }
93
94    /// Get a read-only view of the current task's metadata.
95    pub fn task(&self) -> TaskInfo<'_> {
96        TaskInfo { inner: &self.inner }
97    }
98
99    /// Get a [`ResultSink`] for pushing results back to the daemon.
100    pub fn results(&self) -> ResultSink<'_> {
101        ResultSink { inner: &self.inner }
102    }
103
104    /// Report execution progress (0.0 to 1.0) with a status message.
105    ///
106    /// Progress updates are streamed to the daemon in real-time and can be
107    /// shown in the UI. The `progress` value is clamped to `[0.0, 1.0]`.
108    ///
109    /// If a result channel is available, a PROGRESS-kind message is sent
110    /// over it. Otherwise, just logs the progress.
111    ///
112    /// # Panics
113    ///
114    /// Uses `blocking_send` internally. Must be called from a blocking
115    /// context (e.g. `spawn_blocking`), **not** from within an async task.
116    pub fn progress(&self, pct: f64, message: &str) -> Result<()> {
117        let clamped = pct.clamp(0.0, 1.0);
118        info!(kind = "progress", progress = clamped, %message, "Plugin progress");
119
120        let tx = {
121            let guard = lock_or_err(&self.inner.result_tx)?;
122            guard.clone()
123        };
124
125        if let Some(tx) = tx {
126            let progress_data = serde_json::to_vec(&ProgressPayload {
127                progress: clamped,
128                message,
129            })?;
130
131            let msg = TaskResultMessage {
132                task_id: self.inner.task_id,
133                result_name: String::new(),
134                data: progress_data,
135                format: ResultFormat::Json,
136                is_final: false,
137                kind: ResultKind::Progress,
138                stash_handle: String::new(),
139                stash_format: ResultFormat::Unspecified,
140                stash_size: 0,
141            };
142
143            tx.blocking_send(msg)
144                .map_err(|e| SdkError::Channel(format!("progress: {e}")))?;
145        }
146
147        Ok(())
148    }
149
150    /// Send a transport-level event to the daemon.
151    ///
152    /// Most plugins won't need this - it's an escape hatch for cases
153    /// where you need to signal something outside the normal result flow.
154    pub fn emit_event(&self, event: Event) -> Result<()> {
155        self.inner.emitter.emit(event).map_err(SdkError::Transport)
156    }
157
158    /// Emit a warning via `tracing` at the WARN level.
159    ///
160    /// For guest plugins, warnings are captured by the [`LogBus`](crate::log::LogBus)
161    /// and streamed to the daemon. They appear in the log stream, not in the
162    /// task report - use [`ResultSink::push`] for structured results.
163    pub fn warn(&self, message: &str) -> Result<()> {
164        warn!(kind = "warning", %message, "Plugin warning");
165        Ok(())
166    }
167
168    /// Mark a file path as already handled so auto-collection skips it.
169    ///
170    /// Call this when your plugin reads a file from the artifacts directory
171    /// and sends its own processed version as a result. Without this, the
172    /// auto-collector would send the raw file as a duplicate.
173    pub fn mark_collected(&self, path: impl AsRef<Path>) {
174        if let Ok(canonical) = std::fs::canonicalize(path.as_ref())
175            && let Ok(mut guard) = self.inner.claimed_paths.lock()
176        {
177            guard.insert(canonical);
178        }
179    }
180
181    /// Close the result channel, preventing further pushes.
182    ///
183    /// Called by the runtime after `on_stop` to signal end-of-task.
184    #[allow(dead_code)]
185    pub(crate) fn close_result_channel(&self) {
186        if let Ok(mut guard) = self.inner.result_tx.lock() {
187            *guard = None;
188        }
189    }
190
191    /// Return a snapshot of all claimed file paths.
192    ///
193    /// Used internally by auto-collection to determine which files
194    /// have already been explicitly sent or marked.
195    #[cfg(feature = "guest")]
196    pub(crate) fn claimed_paths(&self) -> HashSet<PathBuf> {
197        self.inner
198            .claimed_paths
199            .lock()
200            .unwrap_or_else(|e| e.into_inner())
201            .clone()
202    }
203}
204
205/// Read-only view of task metadata: the task ID, sample file path, and
206/// key-value configuration submitted with the task.
207pub struct TaskInfo<'a> {
208    inner: &'a ContextInner,
209}
210
211impl TaskInfo<'_> {
212    /// The task's numeric identifier, unique within a single daemon run.
213    pub fn id(&self) -> i32 {
214        self.inner.task_id
215    }
216
217    /// Path to the sample file on disk.
218    pub fn sample_path(&self) -> &Path {
219        &self.inner.sample_path
220    }
221
222    /// Read the entire sample file into memory. Be mindful of large samples.
223    pub fn sample_bytes(&self) -> Result<Vec<u8>> {
224        std::fs::read(&self.inner.sample_path).map_err(SdkError::Io)
225    }
226
227    /// The full key-value configuration map submitted with the task.
228    pub fn config(&self) -> &HashMap<String, String> {
229        &self.inner.config
230    }
231
232    /// Look up a single value from the task configuration. Returns `None`
233    /// if the key is not present.
234    pub fn config_value(&self, key: &str) -> Option<&str> {
235        self.inner.config.get(key).map(|s| s.as_str())
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    fn noop_emitter() -> Arc<dyn TransportEmitter + Send + Sync> {
        Arc::new(())
    }

    /// Build a context wired to a fresh result channel.
    ///
    /// Returns the context and the receiving half of its channel. The
    /// context holds the only sender.
    fn context_with_channel() -> (Context, tokio::sync::mpsc::Receiver<TaskResultMessage>) {
        let (tx, rx) = tokio::sync::mpsc::channel(8);
        let ctx = Context::new(
            1,
            PathBuf::from("/tmp/s.bin"),
            HashMap::new(),
            noop_emitter(),
            Some(tx),
            #[cfg(feature = "guest")]
            None,
        );
        (ctx, rx)
    }

    #[test]
    fn context_is_clone() {
        let ctx = Context::test_new(noop_emitter());
        let ctx2 = ctx.clone();
        assert_eq!(ctx.task().id(), ctx2.task().id());
    }

    #[test]
    fn context_progress_without_channel_is_ok() {
        let ctx = Context::test_new(noop_emitter());

        assert!(ctx.progress(-0.5, "negative").is_ok());
        assert!(ctx.progress(1.5, "over").is_ok());
        assert!(ctx.progress(0.5, "normal").is_ok());
    }

    #[test]
    fn context_progress_clamps_values() {
        // Plain #[test] threads are not inside a Tokio runtime, so the
        // blocking send/recv calls are permitted here.
        let (ctx, mut rx) = context_with_channel();

        for (input, expected) in [(-0.5, 0.0), (1.5, 1.0), (0.5, 0.5)] {
            ctx.progress(input, "step").unwrap();
            let msg = rx.blocking_recv().expect("a progress message");
            assert_eq!(msg.task_id, 1);
            assert!(!msg.is_final);
            assert!(msg.result_name.is_empty());
            assert!(matches!(msg.kind, ResultKind::Progress));

            let payload: serde_json::Value = serde_json::from_slice(&msg.data).unwrap();
            assert_eq!(payload["progress"].as_f64(), Some(expected));
            assert_eq!(payload["message"], "step");
        }
    }

    #[test]
    fn context_progress_errors_when_receiver_dropped() {
        let (ctx, rx) = context_with_channel();
        drop(rx);
        assert!(ctx.progress(0.5, "nobody listening").is_err());
    }

    #[test]
    fn close_result_channel_ends_stream() {
        let (ctx, mut rx) = context_with_channel();
        ctx.close_result_channel();

        // With the channel closed, progress only logs and still succeeds...
        assert!(ctx.progress(0.5, "after close").is_ok());
        // ...and the receiver sees end-of-stream because the only sender
        // was dropped.
        assert!(rx.blocking_recv().is_none());
    }

    #[test]
    fn context_warn_succeeds() {
        let ctx = Context::test_new(noop_emitter());
        assert!(ctx.warn("test warning").is_ok());
    }

    #[test]
    fn context_emit_event_succeeds_with_noop() {
        let ctx = Context::test_new(noop_emitter());
        let result = ctx.emit_event(Event::PluginStarted { plugin_id: 1 });
        assert!(result.is_ok());
    }

    #[test]
    fn task_info_accessors_work() {
        let mut config = HashMap::new();
        config.insert("k".to_string(), "v".to_string());
        let ctx = Context::new(
            7,
            PathBuf::from("/tmp/s.bin"),
            config,
            noop_emitter(),
            None,
            #[cfg(feature = "guest")]
            None,
        );
        assert_eq!(ctx.task().id(), 7);
        assert_eq!(ctx.task().sample_path(), Path::new("/tmp/s.bin"));
        assert_eq!(ctx.task().config().get("k"), Some(&"v".to_string()));
        assert_eq!(ctx.task().config_value("k"), Some("v"));
        assert_eq!(ctx.task().config_value("missing"), None);
    }
}