Skip to main content

mecab_ko_core/viterbi/
mod.rs

1//! Viterbi 알고리즘
2//!
3//! 최적 형태소 분석 경로를 찾는 Viterbi 알고리즘을 구현합니다.
4//!
5//! # 개요
6//!
7//! Viterbi 알고리즘은 Lattice에서 최소 비용 경로를 찾는 동적 프로그래밍 알고리즘입니다.
8//!
9//! ```text
10//! 총 비용 = Σ(단어 비용) + Σ(연접 비용) + Σ(띄어쓰기 패널티)
11//! ```
12//!
13//! # 알고리즘
14//!
15//! 1. **Forward Pass**: BOS에서 시작하여 각 노드까지의 최소 비용 계산
16//! 2. **Backward Pass**: EOS에서 BOS까지 역추적하여 최적 경로 추출
17//!
18//! # 한국어 특화
19//!
20//! - `left-space-penalty-factor`: 띄어쓰기 후 특정 품사 시작 시 페널티 부여
21//! - 조사(JK*), 어미(E*) 등이 띄어쓰기 직후 시작하면 높은 페널티
22//!
23//! # Example
24//!
25//! ```rust,no_run
26//! use mecab_ko_core::viterbi::{ViterbiSearcher, SpacePenalty};
27//! use mecab_ko_core::lattice::Lattice;
28//!
29//! let mut lattice = Lattice::new("아버지가방에");
30//! // ... 노드 추가 후 검색 ...
31//!
32//! let searcher = ViterbiSearcher::new()
33//!     .with_space_penalty(SpacePenalty::korean_default());
34//! ```
35
36use crate::lattice::{Lattice, Node, NodeId, NodeType, INVALID_NODE_ID};
37use std::cmp::Ordering;
38use std::collections::BinaryHeap;
39use std::rc::Rc;
40
41// SIMD 최적화 모듈
42#[cfg(feature = "simd")]
43pub mod simd;
44
45#[cfg(feature = "simd")]
46pub use simd::{simd_forward_pass_position, simd_update_node_cost};
47
48/// 여러 값의 포화 덧셈 (체인)
49///
50/// 오버플로우 방지를 위해 포화 연산 사용
51#[inline(always)]
52const fn saturating_add_chain(a: i32, b: i32, c: i32, d: i32) -> i32 {
53    a.saturating_add(b).saturating_add(c).saturating_add(d)
54}
55
56/// Out-of-range matrix lookup fallback cost.
57///
58/// When `DenseMatrix::get()` returns `i32::MAX` (ID pair outside the matrix
59/// dimensions), this value is substituted so that UNKNOWN nodes can still
60/// participate in the Viterbi search.  In-range cells store `i16` costs
61/// (max 32 767), so `i32::MAX` unambiguously signals an OOB lookup.
62pub(crate) const DEFAULT_OOB_CONNECTION_COST: i32 = 10_000;
63
64#[inline(always)]
65pub(crate) const fn clamp_oob_cost(raw: i32) -> i32 {
66    if raw == i32::MAX {
67        DEFAULT_OOB_CONNECTION_COST
68    } else {
69        raw
70    }
71}
72
73/// 연접 비용 조회 인터페이스
74///
75/// 두 형태소 간의 연결 비용을 반환합니다.
76/// 이 비용은 matrix.def에서 학습된 값입니다.
77pub trait ConnectionCost {
78    /// 두 문맥 ID 간의 연접 비용 반환
79    ///
80    /// # Arguments
81    ///
82    /// * `right_id` - 이전 노드의 우문맥 ID
83    /// * `left_id` - 현재 노드의 좌문맥 ID
84    ///
85    /// # Returns
86    ///
87    /// 연접 비용 (낮을수록 좋음)
88    fn cost(&self, right_id: u16, left_id: u16) -> i32;
89}
90
91/// 더미 연접 비용 (테스트용)
92///
93/// 모든 연접에 대해 0을 반환합니다.
94#[derive(Debug, Clone, Default)]
95pub struct ZeroConnectionCost;
96
97impl ConnectionCost for ZeroConnectionCost {
98    #[inline(always)]
99    fn cost(&self, _right_id: u16, _left_id: u16) -> i32 {
100        0
101    }
102}
103
104/// 고정 연접 비용 (테스트용)
105#[derive(Debug, Clone)]
106pub struct FixedConnectionCost {
107    /// 기본 비용
108    pub default_cost: i32,
109}
110
111impl FixedConnectionCost {
112    /// 새 고정 비용 생성
113    #[must_use]
114    pub const fn new(cost: i32) -> Self {
115        Self { default_cost: cost }
116    }
117}
118
119impl ConnectionCost for FixedConnectionCost {
120    #[inline(always)]
121    fn cost(&self, _right_id: u16, _left_id: u16) -> i32 {
122        self.default_cost
123    }
124}
125
126/// mecab-ko-dict의 `Matrix` trait에 대한 `ConnectionCost` 구현
127///
128/// 사전 모듈의 연접 비용 행렬을 Viterbi 알고리즘에서 직접 사용할 수 있습니다.
129impl<T: mecab_ko_dict::Matrix> ConnectionCost for T {
130    #[inline(always)]
131    fn cost(&self, right_id: u16, left_id: u16) -> i32 {
132        self.get(right_id, left_id)
133    }
134}
135
136/// 띄어쓰기 패널티 설정
137///
138/// mecab-ko의 `left-space-penalty-factor` 기능을 구현합니다.
139/// 띄어쓰기 직후에 특정 품사가 오면 페널티를 부여하여 오분석을 방지합니다.
140///
141/// # Example
142///
143/// ```rust
144/// use mecab_ko_core::viterbi::SpacePenalty;
145///
146/// // mecab-ko 기본 설정
147/// let penalty = SpacePenalty::korean_default();
148///
149/// // dicrc 형식에서 생성
150/// let penalty = SpacePenalty::from_dicrc("1785,6000;1786,6000");
151/// ```
152#[derive(Debug, Clone, Default)]
153pub struct SpacePenalty {
154    /// 페널티를 적용할 품사 ID 목록과 페널티 값
155    /// `(left_id, penalty)`
156    penalties: Vec<(u16, i32)>,
157}
158
159impl SpacePenalty {
160    /// 빈 페널티 설정 생성
161    #[must_use]
162    pub fn new() -> Self {
163        Self::default()
164    }
165
166    /// 한국어 기본 페널티 설정
167    ///
168    /// 조사(JK*)와 어미(E*)가 띄어쓰기 직후 나타나면 높은 페널티를 부여합니다.
169    /// 이는 "아버지가방에" → "아버지가 방에"로 분석하는 데 도움이 됩니다.
170    #[must_use]
171    pub fn korean_default() -> Self {
172        // mecab-ko-dic의 left-id 기준 (실제 값은 사전에 따라 다름)
173        // 여기서는 대표적인 조사/어미 ID 범위를 사용
174        // Build ranges in sorted order so binary_search in get() works correctly.
175
176        // 어미 계열 (EP, EF, EC, ETN, ETM): 1700~1759
177        // 조사 계열 (JKS, JKC, JKG, JKO, JKB, JKV, JKQ, JX, JC): 1780~1809
178        let mut penalties: Vec<(u16, i32)> = (1700u16..1760)
179            .chain(1780..1810)
180            .map(|id| (id, 6000))
181            .collect();
182
183        // Ensure sorted for binary_search
184        penalties.sort_unstable_by_key(|&(id, _)| id);
185        Self { penalties }
186    }
187
188    /// mecab-ko의 dicrc 설정에서 생성
189    ///
190    /// # Format
191    ///
192    /// `left_id,penalty;left_id,penalty;...`
193    ///
194    /// # Example
195    ///
196    /// ```rust
197    /// use mecab_ko_core::viterbi::SpacePenalty;
198    ///
199    /// let penalty = SpacePenalty::from_dicrc("1785,6000;1786,6000;1787,5000");
200    /// assert_eq!(penalty.get(1785), 6000);
201    /// assert_eq!(penalty.get(1786), 6000);
202    /// assert_eq!(penalty.get(9999), 0);  // 미등록
203    /// ```
204    #[must_use]
205    pub fn from_dicrc(config: &str) -> Self {
206        let mut penalties = Vec::new();
207
208        for part in config.split(';') {
209            let parts: Vec<&str> = part.trim().split(',').collect();
210            if parts.len() == 2 {
211                if let (Ok(id), Ok(penalty)) = (
212                    parts[0].trim().parse::<u16>(),
213                    parts[1].trim().parse::<i32>(),
214                ) {
215                    penalties.push((id, penalty));
216                }
217            }
218        }
219
220        // Keep sorted for binary search in get()
221        penalties.sort_unstable_by_key(|&(id, _)| id);
222        Self { penalties }
223    }
224
225    /// 페널티 추가
226    pub fn add(&mut self, left_id: u16, penalty: i32) {
227        // Insert in sorted position for binary search correctness
228        let pos = self.penalties.partition_point(|&(id, _)| id < left_id);
229        self.penalties.insert(pos, (left_id, penalty));
230    }
231
232    /// 특정 품사 ID에 대한 페널티 조회
233    ///
234    /// # Returns
235    ///
236    /// 해당 ID에 설정된 페널티, 없으면 0
237    #[must_use]
238    #[inline]
239    pub fn get(&self, left_id: u16) -> i32 {
240        // Binary search on sorted penalties for O(log n) instead of O(n)
241        self.penalties
242            .binary_search_by_key(&left_id, |&(id, _)| id)
243            .map_or(0, |idx| self.penalties[idx].1)
244    }
245
246    /// 페널티가 설정되어 있는지 확인
247    #[must_use]
248    #[inline]
249    pub fn is_empty(&self) -> bool {
250        self.penalties.is_empty()
251    }
252
253    /// 설정된 페널티 개수
254    #[must_use]
255    #[inline]
256    pub fn len(&self) -> usize {
257        self.penalties.len()
258    }
259}
260
261/// Viterbi 탐색기
262///
263/// Lattice에서 최적 경로를 찾는 Viterbi 알고리즘을 구현합니다.
264#[derive(Debug, Clone)]
265pub struct ViterbiSearcher {
266    /// 띄어쓰기 패널티 설정
267    pub space_penalty: SpacePenalty,
268}
269
270impl Default for ViterbiSearcher {
271    fn default() -> Self {
272        Self::new()
273    }
274}
275
276impl ViterbiSearcher {
277    /// 새 탐색기 생성
278    #[must_use]
279    pub fn new() -> Self {
280        Self {
281            space_penalty: SpacePenalty::default(),
282        }
283    }
284
285    /// 띄어쓰기 패널티 설정
286    #[must_use]
287    pub fn with_space_penalty(mut self, penalty: SpacePenalty) -> Self {
288        self.space_penalty = penalty;
289        self
290    }
291
292    /// 최적 경로 탐색 (Forward-Backward)
293    ///
294    /// # Arguments
295    ///
296    /// * `lattice` - 노드가 추가된 Lattice
297    /// * `conn_cost` - 연접 비용 조회 인터페이스
298    ///
299    /// # Returns
300    ///
301    /// 최적 경로의 노드 ID 목록 (BOS, EOS 제외)
302    ///
303    /// # Example
304    ///
305    /// ```rust,no_run
306    /// # use mecab_ko_core::viterbi::{ViterbiSearcher, SpacePenalty};
307    /// # use mecab_ko_core::lattice::Lattice;
308    /// # let searcher = ViterbiSearcher::new();
309    /// # let conn_cost = mecab_ko_dict::matrix::DenseMatrix::new(1, 1, 0);
310    /// # let mut lattice = Lattice::new("test");
311    /// let path = searcher.search(&mut lattice, &conn_cost);
312    /// for node_id in path {
313    ///     let node = lattice.node(node_id).unwrap();
314    ///     println!("{}: {}", node.surface, node.word_cost);
315    /// }
316    /// ```
317    pub fn search<C: ConnectionCost>(&self, lattice: &mut Lattice, conn_cost: &C) -> Vec<NodeId> {
318        // Forward pass
319        self.forward_pass(lattice, conn_cost);
320
321        // Backward pass
322        Self::backward_pass(lattice)
323    }
324
325    /// Forward Pass: 각 노드의 최소 비용 계산
326    ///
327    /// BOS에서 시작하여 각 위치의 노드들에 대해 최소 비용을 계산합니다.
328    fn forward_pass<C: ConnectionCost>(&self, lattice: &mut Lattice, conn_cost: &C) {
329        let char_len = lattice.char_len();
330
331        // Reusable scratch buffers to avoid per-position Vec allocations.
332        // We collect (node_id) for the starting nodes and (id, total_cost, right_id)
333        // for ending nodes into these, clearing between positions.
334        let mut starting_ids: Vec<NodeId> = Vec::new();
335        let mut ending_nodes: Vec<(NodeId, i32, u16)> = Vec::new();
336
337        // 위치 0부터 끝까지 순회
338        for pos in 0..=char_len {
339            // Collect starting node IDs (need ownership before mutating lattice)
340            starting_ids.clear();
341            starting_ids.extend(lattice.nodes_starting_at(pos).map(|n| n.id));
342
343            // Collect ending node data once per position, reused for every
344            // starting node at this position.
345            ending_nodes.clear();
346            ending_nodes.extend(
347                lattice
348                    .nodes_ending_at(pos)
349                    .map(|n| (n.id, n.total_cost, n.right_id)),
350            );
351
352            for &node_id in &starting_ids {
353                self.update_node_cost_with_endings(lattice, conn_cost, node_id, &ending_nodes);
354            }
355        }
356    }
357
358    /// 단일 노드의 최소 비용 계산 및 업데이트 (사전 수집된 `ending_nodes` 사용)
359    ///
360    /// Hot path: 성능 최적화를 위해 인라인 처리
361    /// SIMD 최적화: 8개 이상의 이전 노드가 있으면 SIMD 배치 처리 사용
362    #[inline]
363    fn update_node_cost_with_endings<C: ConnectionCost>(
364        &self,
365        lattice: &mut Lattice,
366        conn_cost: &C,
367        node_id: NodeId,
368        ending_nodes: &[(NodeId, i32, u16)],
369    ) {
370        // SIMD 최적화: 8개 이상의 이전 노드가 있으면 SIMD 사용
371        #[cfg(feature = "simd")]
372        if ending_nodes.len() >= 8 {
373            let (best_cost, best_prev) = simd::simd_update_node_cost(
374                lattice,
375                conn_cost,
376                node_id,
377                ending_nodes,
378                &self.space_penalty,
379            );
380            if let Some(node) = lattice.node_mut(node_id) {
381                node.total_cost = best_cost;
382                node.prev_node_id = best_prev;
383            }
384            return;
385        }
386
387        // 현재 노드 정보 추출
388        let (left_id, word_cost, has_space) = {
389            let Some(node) = lattice.node(node_id) else {
390                return;
391            };
392            (node.left_id, node.word_cost, node.has_space_before)
393        };
394
395        // 띄어쓰기 패널티는 left_id에 대해 한 번만 조회
396        let space_penalty = if has_space {
397            self.space_penalty.get(left_id)
398        } else {
399            0
400        };
401
402        let mut best_cost = i32::MAX;
403        let mut best_prev = INVALID_NODE_ID;
404
405        for &(prev_id, prev_cost, prev_right_id) in ending_nodes {
406            // 이전 노드까지의 비용이 무한대면 스킵
407            if prev_cost == i32::MAX {
408                continue;
409            }
410
411            // 연접 비용 계산
412            let connection = clamp_oob_cost(conn_cost.cost(prev_right_id, left_id));
413
414            // 총 비용 = 이전 비용 + 연접 비용 + 단어 비용 + 띄어쓰기 패널티
415            let total = saturating_add_chain(prev_cost, connection, word_cost, space_penalty);
416
417            if total < best_cost {
418                best_cost = total;
419                best_prev = prev_id;
420            }
421        }
422
423        // 노드 업데이트
424        if let Some(node) = lattice.node_mut(node_id) {
425            node.total_cost = best_cost;
426            node.prev_node_id = best_prev;
427        }
428    }
429
430    /// 단일 노드의 최소 비용 계산 및 업데이트 (레거시, 테스트용으로 유지)
431    #[cfg(test)]
432    #[allow(dead_code)]
433    fn update_node_cost<C: ConnectionCost>(
434        &self,
435        lattice: &mut Lattice,
436        conn_cost: &C,
437        node_id: NodeId,
438        pos: usize,
439    ) {
440        // 현재 노드 정보 추출
441        let (left_id, word_cost, has_space) = {
442            let Some(node) = lattice.node(node_id) else {
443                return;
444            };
445            (node.left_id, node.word_cost, node.has_space_before)
446        };
447
448        // 이 노드로 연결될 수 있는 이전 노드들 (pos에서 끝나는 노드들)
449        let ending_nodes: Vec<(NodeId, i32, u16)> = lattice
450            .nodes_ending_at(pos)
451            .map(|n| (n.id, n.total_cost, n.right_id))
452            .collect();
453
454        let mut best_cost = i32::MAX;
455        let mut best_prev = INVALID_NODE_ID;
456
457        for (prev_id, prev_cost, prev_right_id) in ending_nodes {
458            if prev_cost == i32::MAX {
459                continue;
460            }
461
462            let connection = clamp_oob_cost(conn_cost.cost(prev_right_id, left_id));
463
464            let space_penalty = if has_space {
465                self.space_penalty.get(left_id)
466            } else {
467                0
468            };
469
470            let total = prev_cost
471                .saturating_add(connection)
472                .saturating_add(word_cost)
473                .saturating_add(space_penalty);
474
475            if total < best_cost {
476                best_cost = total;
477                best_prev = prev_id;
478            }
479        }
480
481        // 노드 업데이트
482        if let Some(node) = lattice.node_mut(node_id) {
483            node.total_cost = best_cost;
484            node.prev_node_id = best_prev;
485        }
486    }
487
488    /// Backward Pass: EOS에서 BOS까지 역추적
489    ///
490    /// 최적 경로의 노드 ID 목록을 반환합니다 (BOS, EOS 제외).
491    fn backward_pass(lattice: &Lattice) -> Vec<NodeId> {
492        let mut path = Vec::new();
493        let mut current_id = lattice.eos().id;
494
495        while current_id != INVALID_NODE_ID {
496            if let Some(node) = lattice.node(current_id) {
497                // BOS, EOS는 결과에서 제외
498                if node.node_type != NodeType::Bos && node.node_type != NodeType::Eos {
499                    path.push(current_id);
500                }
501                current_id = node.prev_node_id;
502            } else {
503                break;
504            }
505        }
506
507        path.reverse();
508        path
509    }
510
511    /// 최적 경로의 총 비용 조회
512    #[must_use]
513    pub fn get_best_cost(&self, lattice: &Lattice) -> i32 {
514        lattice.eos().total_cost
515    }
516
517    /// 경로가 유효한지 확인
518    ///
519    /// EOS까지의 경로가 존재하는지 확인합니다.
520    #[must_use]
521    pub fn has_valid_path(&self, lattice: &Lattice) -> bool {
522        lattice.eos().total_cost != i32::MAX && lattice.eos().prev_node_id != INVALID_NODE_ID
523    }
524}
525
526// ============================================
527// N-best 지원
528// ============================================
529
530/// N-best 경로 노드 (링크드 리스트)
531///
532/// 경로를 Rc로 연결하여 클론 비용을 줄입니다.
533/// 전체 경로를 복사하는 대신 참조 카운트만 증가시킵니다.
534#[derive(Debug, Clone)]
535struct PathNode {
536    /// 현재 노드 ID
537    node_id: NodeId,
538    /// 이전 경로 노드 (Rc로 공유)
539    prev: Option<Rc<Self>>,
540}
541
542impl PathNode {
543    /// 새 경로 노드 생성
544    ///
545    /// Note: Cannot be const due to `Rc<Self>` parameter
546    #[allow(clippy::missing_const_for_fn)]
547    fn new(node_id: NodeId, prev: Option<Rc<Self>>) -> Self {
548        Self { node_id, prev }
549    }
550
551    /// 경로를 Vec로 변환 (BOS에서 현재 노드까지)
552    fn to_vec(&self) -> Vec<NodeId> {
553        let mut path = Vec::new();
554        let mut current = Some(self);
555
556        while let Some(node) = current {
557            path.push(node.node_id);
558            current = node.prev.as_ref().map(std::convert::AsRef::as_ref);
559        }
560
561        path.reverse();
562        path
563    }
564}
565
566/// N-best 경로 후보
567#[derive(Debug, Clone)]
568struct NbestCandidate {
569    /// 노드 ID
570    node_id: NodeId,
571    /// 총 비용
572    cost: i32,
573    /// 이전 경로 (Rc로 공유되는 링크드 리스트)
574    path: Option<Rc<PathNode>>,
575}
576
577impl Eq for NbestCandidate {}
578
579impl PartialEq for NbestCandidate {
580    fn eq(&self, other: &Self) -> bool {
581        self.cost == other.cost
582    }
583}
584
585impl Ord for NbestCandidate {
586    fn cmp(&self, other: &Self) -> Ordering {
587        // Min-heap: 비용이 낮은 것이 우선
588        other.cost.cmp(&self.cost)
589    }
590}
591
592impl PartialOrd for NbestCandidate {
593    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
594        Some(self.cmp(other))
595    }
596}
597
598/// N-best 탐색기
599///
600/// 상위 N개의 최적 경로를 찾습니다.
601#[derive(Debug, Clone)]
602pub struct NbestSearcher {
603    /// 기본 Viterbi 탐색기
604    viterbi: ViterbiSearcher,
605    /// 최대 결과 수
606    max_results: usize,
607}
608
609impl NbestSearcher {
610    /// 새 N-best 탐색기 생성
611    #[must_use]
612    pub fn new(n: usize) -> Self {
613        Self {
614            viterbi: ViterbiSearcher::new(),
615            max_results: n,
616        }
617    }
618
619    /// 띄어쓰기 패널티 설정
620    #[must_use]
621    pub fn with_space_penalty(mut self, penalty: SpacePenalty) -> Self {
622        self.viterbi.space_penalty = penalty;
623        self
624    }
625
626    /// N-best 경로 탐색
627    ///
628    /// # Arguments
629    ///
630    /// * `lattice` - 노드가 추가된 Lattice
631    /// * `conn_cost` - 연접 비용 조회 인터페이스
632    ///
633    /// # Returns
634    ///
635    /// (경로, 비용) 쌍의 벡터, 비용 오름차순
636    pub fn search<C: ConnectionCost>(
637        &self,
638        lattice: &mut Lattice,
639        conn_cost: &C,
640    ) -> Vec<(Vec<NodeId>, i32)> {
641        // 먼저 Forward pass 실행
642        self.viterbi.forward_pass(lattice, conn_cost);
643
644        // 최적 경로가 없으면 빈 결과 반환
645        if !self.viterbi.has_valid_path(lattice) {
646            return Vec::new();
647        }
648
649        // 1-best인 경우 단순 backward pass
650        if self.max_results == 1 {
651            let path = ViterbiSearcher::backward_pass(lattice);
652            let cost = self.viterbi.get_best_cost(lattice);
653            return vec![(path, cost)];
654        }
655
656        // N-best: A* 유사 알고리즘
657        self.search_nbest(lattice, conn_cost)
658    }
659
660    /// N-best 경로 탐색 (A* 기반)
661    ///
662    /// # 최적화
663    ///
664    /// 경로를 `Rc<PathNode>`로 표현하여 클론 비용을 최소화합니다.
665    /// 전체 Vec를 복사하는 대신 참조 카운트만 증가시켜 O(1) 클론을 달성합니다.
666    fn search_nbest<C: ConnectionCost>(
667        &self,
668        lattice: &Lattice,
669        _conn_cost: &C,
670    ) -> Vec<(Vec<NodeId>, i32)> {
671        let mut results: Vec<(Vec<NodeId>, i32)> = Vec::new();
672        let mut heap: BinaryHeap<NbestCandidate> = BinaryHeap::new();
673
674        // EOS에서 시작
675        let eos = lattice.eos();
676        if eos.total_cost == i32::MAX {
677            return results;
678        }
679
680        heap.push(NbestCandidate {
681            node_id: eos.id,
682            cost: eos.total_cost,
683            path: None,
684        });
685
686        while let Some(candidate) = heap.pop() {
687            if results.len() >= self.max_results {
688                break;
689            }
690
691            let Some(node) = lattice.node(candidate.node_id) else {
692                continue;
693            };
694
695            // 현재까지의 경로 (Rc 클론은 O(1))
696            let current_path = if node.node_type != NodeType::Bos && node.node_type != NodeType::Eos
697            {
698                // BOS, EOS가 아니면 경로에 추가
699                Some(Rc::new(PathNode::new(candidate.node_id, candidate.path)))
700            } else {
701                candidate.path
702            };
703
704            // BOS에 도달하면 결과에 추가
705            if node.node_type == NodeType::Bos {
706                // 경로를 Vec로 변환 (완료된 경로만)
707                let path_vec = current_path.map_or_else(Vec::new, |path_node| path_node.to_vec());
708                results.push((path_vec, candidate.cost));
709                continue;
710            }
711
712            // 이전 노드로 계속 탐색
713            if node.prev_node_id != INVALID_NODE_ID {
714                heap.push(NbestCandidate {
715                    node_id: node.prev_node_id,
716                    cost: candidate.cost,
717                    path: current_path,
718                });
719            }
720        }
721
722        results
723    }
724}
725
726/// Viterbi 결과를 Token으로 변환하는 헬퍼
727pub struct ViterbiResult<'a> {
728    /// Lattice 참조
729    lattice: &'a Lattice,
730    /// 최적 경로 노드 ID
731    path: Vec<NodeId>,
732    /// 총 비용
733    total_cost: i32,
734}
735
736impl<'a> ViterbiResult<'a> {
737    /// 결과 생성
738    #[must_use]
739    pub const fn new(lattice: &'a Lattice, path: Vec<NodeId>, total_cost: i32) -> Self {
740        Self {
741            lattice,
742            path,
743            total_cost,
744        }
745    }
746
747    /// 경로의 노드들 반복
748    pub fn nodes(&self) -> impl Iterator<Item = &'a Node> + '_ {
749        self.path.iter().filter_map(|&id| self.lattice.node(id))
750    }
751
752    /// 총 비용
753    #[must_use]
754    pub const fn cost(&self) -> i32 {
755        self.total_cost
756    }
757
758    /// 노드 개수
759    #[must_use]
760    pub fn len(&self) -> usize {
761        self.path.len()
762    }
763
764    /// 비어있는지 확인
765    #[must_use]
766    pub fn is_empty(&self) -> bool {
767        self.path.is_empty()
768    }
769
770    /// 표면형 목록
771    #[must_use]
772    pub fn surfaces(&self) -> Vec<&str> {
773        self.nodes().map(|n| n.surface.as_ref()).collect()
774    }
775}
776
777#[cfg(test)]
778#[allow(clippy::unwrap_used)]
779mod tests {
780    use super::*;
781    use crate::lattice::NodeBuilder;
782
783    /// 테스트용 연접 비용 행렬
784    struct TestConnectionCost {
785        costs: std::collections::HashMap<(u16, u16), i32>,
786        default: i32,
787    }
788
789    impl TestConnectionCost {
790        fn new(default: i32) -> Self {
791            Self {
792                costs: std::collections::HashMap::new(),
793                default,
794            }
795        }
796
797        fn set(&mut self, right_id: u16, left_id: u16, cost: i32) {
798            self.costs.insert((right_id, left_id), cost);
799        }
800    }
801
802    impl ConnectionCost for TestConnectionCost {
803        fn cost(&self, right_id: u16, left_id: u16) -> i32 {
804            self.costs
805                .get(&(right_id, left_id))
806                .copied()
807                .unwrap_or(self.default)
808        }
809    }
810
811    #[test]
812    fn test_space_penalty_default() {
813        let penalty = SpacePenalty::default();
814        assert!(penalty.is_empty());
815        assert_eq!(penalty.get(100), 0);
816    }
817
818    #[test]
819    fn test_space_penalty_from_dicrc() {
820        let penalty = SpacePenalty::from_dicrc("100,5000;200,3000;300,1000");
821
822        assert_eq!(penalty.len(), 3);
823        assert_eq!(penalty.get(100), 5000);
824        assert_eq!(penalty.get(200), 3000);
825        assert_eq!(penalty.get(300), 1000);
826        assert_eq!(penalty.get(999), 0); // 미등록
827    }
828
829    #[test]
830    fn test_space_penalty_korean_default() {
831        let penalty = SpacePenalty::korean_default();
832        assert!(!penalty.is_empty());
833
834        // 조사 범위에 대해 페널티가 설정되어 있어야 함
835        assert!(penalty.get(1785) > 0);
836    }
837
838    #[test]
839    fn test_viterbi_simple_path() {
840        // 간단한 Lattice: "AB"
841        // BOS -> [A] -> [B] -> EOS
842        let mut lattice = Lattice::new("AB");
843
844        // A 노드 (위치 0-1)
845        lattice.add_node(
846            NodeBuilder::new("A", 0, 1)
847                .left_id(1)
848                .right_id(1)
849                .word_cost(100),
850        );
851
852        // B 노드 (위치 1-2)
853        lattice.add_node(
854            NodeBuilder::new("B", 1, 2)
855                .left_id(2)
856                .right_id(2)
857                .word_cost(200),
858        );
859
860        let conn_cost = ZeroConnectionCost;
861        let searcher = ViterbiSearcher::new();
862
863        let path = searcher.search(&mut lattice, &conn_cost);
864
865        assert_eq!(path.len(), 2);
866
867        // 첫 번째 노드는 "A"
868        let first = lattice.node(path[0]).unwrap();
869        assert_eq!(first.surface.as_ref(), "A");
870
871        // 두 번째 노드는 "B"
872        let second = lattice.node(path[1]).unwrap();
873        assert_eq!(second.surface.as_ref(), "B");
874
875        // 총 비용 확인
876        let total_cost = searcher.get_best_cost(&lattice);
877        assert_eq!(total_cost, 300); // 100 + 200
878    }
879
880    #[test]
881    fn test_viterbi_choose_best_path() {
882        // 두 가지 경로가 있는 Lattice: "AB"
883        // 경로 1: BOS -> [AB] -> EOS (비용: 500)
884        // 경로 2: BOS -> [A] -> [B] -> EOS (비용: 100 + 200 = 300)
885        let mut lattice = Lattice::new("AB");
886
887        // AB 노드 (위치 0-2) - 비용 높음
888        lattice.add_node(
889            NodeBuilder::new("AB", 0, 2)
890                .left_id(1)
891                .right_id(1)
892                .word_cost(500),
893        );
894
895        // A 노드 (위치 0-1)
896        lattice.add_node(
897            NodeBuilder::new("A", 0, 1)
898                .left_id(2)
899                .right_id(2)
900                .word_cost(100),
901        );
902
903        // B 노드 (위치 1-2)
904        lattice.add_node(
905            NodeBuilder::new("B", 1, 2)
906                .left_id(3)
907                .right_id(3)
908                .word_cost(200),
909        );
910
911        let conn_cost = ZeroConnectionCost;
912        let searcher = ViterbiSearcher::new();
913
914        let path = searcher.search(&mut lattice, &conn_cost);
915
916        // 더 낮은 비용의 경로 선택: A + B
917        assert_eq!(path.len(), 2);
918        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "A");
919        assert_eq!(lattice.node(path[1]).unwrap().surface.as_ref(), "B");
920    }
921
922    #[test]
923    fn test_viterbi_with_connection_cost() {
924        // 연접 비용이 경로 선택에 영향
925        // 경로 1: BOS -> [AB] -> EOS (단어: 300, 연접: 0)
926        // 경로 2: BOS -> [A] -> [B] -> EOS (단어: 100+100=200, 연접: 500)
927        let mut lattice = Lattice::new("AB");
928
929        // AB 노드
930        lattice.add_node(
931            NodeBuilder::new("AB", 0, 2)
932                .left_id(1)
933                .right_id(1)
934                .word_cost(300),
935        );
936
937        // A 노드
938        lattice.add_node(
939            NodeBuilder::new("A", 0, 1)
940                .left_id(2)
941                .right_id(2)
942                .word_cost(100),
943        );
944
945        // B 노드
946        lattice.add_node(
947            NodeBuilder::new("B", 1, 2)
948                .left_id(3)
949                .right_id(3)
950                .word_cost(100),
951        );
952
953        let mut conn_cost = TestConnectionCost::new(0);
954        // A -> B 연접에 높은 비용 설정
955        conn_cost.set(2, 3, 500);
956
957        let searcher = ViterbiSearcher::new();
958        let path = searcher.search(&mut lattice, &conn_cost);
959
960        // 연접 비용 때문에 AB 선택: 300 < 200 + 500
961        assert_eq!(path.len(), 1);
962        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "AB");
963    }
964
965    #[test]
966    fn test_viterbi_with_space_penalty() {
967        // 띄어쓰기 패널티 테스트
968        // "A B" (공백 있음)
969        // B의 left_id에 패널티가 있으면 다른 경로 선택
970        let mut lattice = Lattice::new("A B");
971        // 공백 제거 후 "AB"
972
973        // AB 노드 (전체)
974        lattice.add_node(
975            NodeBuilder::new("AB", 0, 2)
976                .left_id(1)
977                .right_id(1)
978                .word_cost(500),
979        );
980
981        // A 노드
982        lattice.add_node(
983            NodeBuilder::new("A", 0, 1)
984                .left_id(2)
985                .right_id(2)
986                .word_cost(100),
987        );
988
989        // B 노드 (공백 뒤에서 시작)
990        lattice.add_node(
991            NodeBuilder::new("B", 1, 2)
992                .left_id(100) // 페널티가 적용될 ID
993                .right_id(3)
994                .word_cost(100)
995                .has_space_before(true),
996        );
997
998        // B의 left_id에 높은 페널티 설정
999        let mut penalty = SpacePenalty::new();
1000        penalty.add(100, 1000);
1001
1002        let conn_cost = ZeroConnectionCost;
1003        let searcher = ViterbiSearcher::new().with_space_penalty(penalty);
1004
1005        let path = searcher.search(&mut lattice, &conn_cost);
1006
1007        // 페널티 때문에 AB 선택: 500 < 100 + 100 + 1000
1008        assert_eq!(path.len(), 1);
1009        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "AB");
1010    }
1011
1012    #[test]
1013    fn test_viterbi_korean_example() {
1014        // 한국어 예시: "아버지가"
1015        let mut lattice = Lattice::new("아버지가");
1016
1017        // 경로 1: "아버지" + "가" (조사)
1018        lattice.add_node(
1019            NodeBuilder::new("아버지", 0, 3)
1020                .left_id(1)
1021                .right_id(1)
1022                .word_cost(1000),
1023        );
1024        lattice.add_node(
1025            NodeBuilder::new("가", 3, 4)
1026                .left_id(100) // 조사
1027                .right_id(100)
1028                .word_cost(500),
1029        );
1030
1031        // 경로 2: "아버" + "지가"
1032        lattice.add_node(
1033            NodeBuilder::new("아버", 0, 2)
1034                .left_id(2)
1035                .right_id(2)
1036                .word_cost(3000),
1037        );
1038        lattice.add_node(
1039            NodeBuilder::new("지가", 2, 4)
1040                .left_id(3)
1041                .right_id(3)
1042                .word_cost(3000),
1043        );
1044
1045        let conn_cost = ZeroConnectionCost;
1046        let searcher = ViterbiSearcher::new();
1047
1048        let path = searcher.search(&mut lattice, &conn_cost);
1049
1050        // "아버지" + "가" 선택 (비용: 1500 < 6000)
1051        assert_eq!(path.len(), 2);
1052        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "아버지");
1053        assert_eq!(lattice.node(path[1]).unwrap().surface.as_ref(), "가");
1054    }
1055
1056    #[test]
1057    fn test_viterbi_empty_lattice() {
1058        let mut lattice = Lattice::new("");
1059
1060        let conn_cost = ZeroConnectionCost;
1061        let searcher = ViterbiSearcher::new();
1062
1063        let path = searcher.search(&mut lattice, &conn_cost);
1064
1065        // 빈 텍스트는 빈 경로
1066        assert!(path.is_empty());
1067    }
1068
1069    #[test]
1070    fn test_viterbi_no_path() {
1071        // 노드가 연결되지 않는 경우
1072        let mut lattice = Lattice::new("ABC");
1073
1074        // A만 있고 B, C 없음 -> EOS에 도달 불가
1075        lattice.add_node(
1076            NodeBuilder::new("A", 0, 1)
1077                .left_id(1)
1078                .right_id(1)
1079                .word_cost(100),
1080        );
1081
1082        let conn_cost = ZeroConnectionCost;
1083        let searcher = ViterbiSearcher::new();
1084
1085        let path = searcher.search(&mut lattice, &conn_cost);
1086
1087        // 유효한 경로 없음
1088        assert!(!searcher.has_valid_path(&lattice));
1089        assert!(path.is_empty());
1090    }
1091
1092    #[test]
1093    fn test_nbest_single() {
1094        let mut lattice = Lattice::new("AB");
1095
1096        lattice.add_node(
1097            NodeBuilder::new("A", 0, 1)
1098                .left_id(1)
1099                .right_id(1)
1100                .word_cost(100),
1101        );
1102        lattice.add_node(
1103            NodeBuilder::new("B", 1, 2)
1104                .left_id(2)
1105                .right_id(2)
1106                .word_cost(200),
1107        );
1108
1109        let conn_cost = ZeroConnectionCost;
1110        let searcher = NbestSearcher::new(1);
1111
1112        let results = searcher.search(&mut lattice, &conn_cost);
1113
1114        assert_eq!(results.len(), 1);
1115        assert_eq!(results[0].1, 300); // 비용
1116    }
1117
1118    #[test]
1119    fn test_viterbi_result_helper() {
1120        let mut lattice = Lattice::new("AB");
1121
1122        let _id1 = lattice.add_node(
1123            NodeBuilder::new("A", 0, 1)
1124                .left_id(1)
1125                .right_id(1)
1126                .word_cost(100),
1127        );
1128        let _id2 = lattice.add_node(
1129            NodeBuilder::new("B", 1, 2)
1130                .left_id(2)
1131                .right_id(2)
1132                .word_cost(200),
1133        );
1134
1135        let conn_cost = ZeroConnectionCost;
1136        let searcher = ViterbiSearcher::new();
1137        let path = searcher.search(&mut lattice, &conn_cost);
1138        let cost = searcher.get_best_cost(&lattice);
1139
1140        let result = ViterbiResult::new(&lattice, path, cost);
1141
1142        assert_eq!(result.len(), 2);
1143        assert_eq!(result.cost(), 300);
1144        assert_eq!(result.surfaces(), vec!["A", "B"]);
1145    }
1146
1147    #[test]
1148    fn test_viterbi_with_dense_matrix() {
1149        use mecab_ko_dict::DenseMatrix;
1150
1151        // 3x3 연접 비용 행렬 생성
1152        // left_id: 0=BOS, 1=명사, 2=조사
1153        // right_id: 0=EOS, 1=명사, 2=조사
1154        let mut matrix = DenseMatrix::new(3, 3, 0);
1155
1156        // 연접 비용 설정
1157        // BOS -> 명사: 낮은 비용 (자연스러움)
1158        matrix.set(0, 1, 100);
1159        // 명사 -> 조사: 낮은 비용 (자연스러움)
1160        matrix.set(1, 2, 50);
1161        // 조사 -> EOS: 낮은 비용
1162        matrix.set(2, 0, 30);
1163
1164        // BOS -> 조사: 높은 비용 (부자연스러움)
1165        matrix.set(0, 2, 5000);
1166        // 명사 -> EOS: 중간 비용
1167        matrix.set(1, 0, 200);
1168
1169        let mut lattice = Lattice::new("책을");
1170
1171        // "책" (명사) - 문자 위치 0..1
1172        lattice.add_node(
1173            NodeBuilder::new("책", 0, 1)
1174                .left_id(1) // 명사 left_id
1175                .right_id(1) // 명사 right_id
1176                .word_cost(500),
1177        );
1178
1179        // "을" (조사) - 문자 위치 1..2
1180        lattice.add_node(
1181            NodeBuilder::new("을", 1, 2)
1182                .left_id(2) // 조사 left_id
1183                .right_id(2) // 조사 right_id
1184                .word_cost(100),
1185        );
1186
1187        let searcher = ViterbiSearcher::new();
1188        let path = searcher.search(&mut lattice, &matrix);
1189
1190        // BOS -> 명사 -> 조사 -> EOS 경로 확인
1191        assert!(!path.is_empty());
1192
1193        let result = ViterbiResult::new(&lattice, path, searcher.get_best_cost(&lattice));
1194        assert_eq!(result.surfaces(), vec!["책", "을"]);
1195
1196        // 총 비용: BOS->명사(100) + 명사비용(500) + 명사->조사(50) + 조사비용(100) + 조사->EOS(30)
1197        // = 100 + 500 + 50 + 100 + 30 = 780
1198        assert_eq!(result.cost(), 780);
1199    }
1200}