1use crate::lattice::{Lattice, Node, NodeId, NodeType, INVALID_NODE_ID};
37use std::cmp::Ordering;
38use std::collections::BinaryHeap;
39use std::rc::Rc;
40
// SIMD-accelerated helpers used by the forward pass; compiled only when the
// `simd` cargo feature is enabled.
#[cfg(feature = "simd")]
pub mod simd;

#[cfg(feature = "simd")]
pub use simd::{simd_forward_pass_position, simd_update_node_cost};
47
/// Adds four costs left to right, saturating at `i32::MIN`/`i32::MAX` after each step.
///
/// Saturation is applied per step, not once at the end: for example
/// `(i32::MAX, 1, -5, 0)` yields `i32::MAX - 5`.
#[inline(always)]
const fn saturating_add_chain(a: i32, b: i32, c: i32, d: i32) -> i32 {
    let first_two = a.saturating_add(b);
    let first_three = first_two.saturating_add(c);
    first_three.saturating_add(d)
}
55
/// Connection cost substituted whenever a connection model reports `i32::MAX`.
///
/// The name suggests `i32::MAX` signals an out-of-bounds / missing matrix entry
/// (NOTE(review): confirm against the dictionary's matrix implementation).
pub(crate) const DEFAULT_OOB_CONNECTION_COST: i32 = 10_000;

/// Replaces the `i32::MAX` sentinel with [`DEFAULT_OOB_CONNECTION_COST`].
///
/// Every other value, including negative costs, is returned unchanged.
#[inline(always)]
pub(crate) const fn clamp_oob_cost(raw: i32) -> i32 {
    match raw {
        i32::MAX => DEFAULT_OOB_CONNECTION_COST,
        cost => cost,
    }
}
72
/// Source of connection (transition) costs between adjacent morphemes.
pub trait ConnectionCost {
    /// Cost of placing a morpheme whose left context id is `left_id` directly
    /// after a morpheme whose right context id is `right_id`.
    ///
    /// The scalar search path passes the returned value through
    /// `clamp_oob_cost`, so `i32::MAX` is replaced by
    /// `DEFAULT_OOB_CONNECTION_COST` rather than making the transition unusable.
    fn cost(&self, right_id: u16, left_id: u16) -> i32;
}
90
91#[derive(Debug, Clone, Default)]
95pub struct ZeroConnectionCost;
96
97impl ConnectionCost for ZeroConnectionCost {
98 #[inline(always)]
99 fn cost(&self, _right_id: u16, _left_id: u16) -> i32 {
100 0
101 }
102}
103
104#[derive(Debug, Clone)]
106pub struct FixedConnectionCost {
107 pub default_cost: i32,
109}
110
111impl FixedConnectionCost {
112 #[must_use]
114 pub const fn new(cost: i32) -> Self {
115 Self { default_cost: cost }
116 }
117}
118
119impl ConnectionCost for FixedConnectionCost {
120 #[inline(always)]
121 fn cost(&self, _right_id: u16, _left_id: u16) -> i32 {
122 self.default_cost
123 }
124}
125
126impl<T: mecab_ko_dict::Matrix> ConnectionCost for T {
130 #[inline(always)]
131 fn cost(&self, right_id: u16, left_id: u16) -> i32 {
132 self.get(right_id, left_id)
133 }
134}
135
/// Extra cost charged to a morpheme that is preceded by whitespace, keyed by
/// the morpheme's `left_id`.
///
/// Invariant: entries are sorted by `left_id` and each id appears at most once,
/// so [`SpacePenalty::get`] is an unambiguous binary search.
#[derive(Debug, Clone, Default)]
pub struct SpacePenalty {
    /// `(left_id, penalty)` pairs, sorted by `left_id`, ids unique.
    penalties: Vec<(u16, i32)>,
}

impl SpacePenalty {
    /// Creates an empty table; every lookup returns `0`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Built-in Korean table: a penalty of `6000` for left ids `1700..1760`
    /// and `1780..1810`.
    #[must_use]
    pub fn korean_default() -> Self {
        Self::from_pairs((1700u16..1760).chain(1780..1810).map(|id| (id, 6000)))
    }

    /// Parses a `left_id,penalty;left_id,penalty;...` list.
    ///
    /// Entries that do not consist of exactly one `u16` and one `i32` separated
    /// by a comma are ignored. If the same id appears more than once, the last
    /// entry wins (previously duplicates were kept and `get` could return any of them).
    #[must_use]
    pub fn from_dicrc(config: &str) -> Self {
        let pairs = config.split(';').filter_map(|entry| {
            let (id, penalty) = entry.trim().split_once(',')?;
            // A second comma leaves it in `penalty`, which then fails to parse,
            // so "1,2,3" is rejected just like any malformed entry.
            let id = id.trim().parse::<u16>().ok()?;
            let penalty = penalty.trim().parse::<i32>().ok()?;
            Some((id, penalty))
        });
        Self::from_pairs(pairs)
    }

    /// Sets the penalty for `left_id`, replacing any existing entry for that id.
    pub fn add(&mut self, left_id: u16, penalty: i32) {
        match self.penalties.binary_search_by_key(&left_id, |&(id, _)| id) {
            Ok(idx) => self.penalties[idx].1 = penalty,
            Err(idx) => self.penalties.insert(idx, (left_id, penalty)),
        }
    }

    /// Returns the penalty for `left_id`, or `0` when none is configured.
    #[must_use]
    #[inline]
    pub fn get(&self, left_id: u16) -> i32 {
        self.penalties
            .binary_search_by_key(&left_id, |&(id, _)| id)
            .map_or(0, |idx| self.penalties[idx].1)
    }

    /// Returns `true` when no penalties are configured.
    #[must_use]
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.penalties.is_empty()
    }

    /// Number of distinct `left_id`s with a configured penalty.
    #[must_use]
    #[inline]
    pub fn len(&self) -> usize {
        self.penalties.len()
    }

    /// Builds a table that upholds the sorted/unique invariant; later pairs
    /// override earlier pairs with the same id.
    fn from_pairs<I: IntoIterator<Item = (u16, i32)>>(pairs: I) -> Self {
        let mut by_id = std::collections::BTreeMap::new();
        for (id, penalty) in pairs {
            // `insert` overwrites, which gives the "last one wins" rule.
            by_id.insert(id, penalty);
        }
        Self {
            penalties: by_id.into_iter().collect(),
        }
    }
}
260
261#[derive(Debug, Clone)]
265pub struct ViterbiSearcher {
266 pub space_penalty: SpacePenalty,
268}
269
270impl Default for ViterbiSearcher {
271 fn default() -> Self {
272 Self::new()
273 }
274}
275
276impl ViterbiSearcher {
277 #[must_use]
279 pub fn new() -> Self {
280 Self {
281 space_penalty: SpacePenalty::default(),
282 }
283 }
284
285 #[must_use]
287 pub fn with_space_penalty(mut self, penalty: SpacePenalty) -> Self {
288 self.space_penalty = penalty;
289 self
290 }
291
292 pub fn search<C: ConnectionCost>(&self, lattice: &mut Lattice, conn_cost: &C) -> Vec<NodeId> {
318 self.forward_pass(lattice, conn_cost);
320
321 Self::backward_pass(lattice)
323 }
324
325 fn forward_pass<C: ConnectionCost>(&self, lattice: &mut Lattice, conn_cost: &C) {
329 let char_len = lattice.char_len();
330
331 let mut starting_ids: Vec<NodeId> = Vec::new();
335 let mut ending_nodes: Vec<(NodeId, i32, u16)> = Vec::new();
336
337 for pos in 0..=char_len {
339 starting_ids.clear();
341 starting_ids.extend(lattice.nodes_starting_at(pos).map(|n| n.id));
342
343 ending_nodes.clear();
346 ending_nodes.extend(
347 lattice
348 .nodes_ending_at(pos)
349 .map(|n| (n.id, n.total_cost, n.right_id)),
350 );
351
352 for &node_id in &starting_ids {
353 self.update_node_cost_with_endings(lattice, conn_cost, node_id, &ending_nodes);
354 }
355 }
356 }
357
358 #[inline]
363 fn update_node_cost_with_endings<C: ConnectionCost>(
364 &self,
365 lattice: &mut Lattice,
366 conn_cost: &C,
367 node_id: NodeId,
368 ending_nodes: &[(NodeId, i32, u16)],
369 ) {
370 #[cfg(feature = "simd")]
372 if ending_nodes.len() >= 8 {
373 let (best_cost, best_prev) = simd::simd_update_node_cost(
374 lattice,
375 conn_cost,
376 node_id,
377 ending_nodes,
378 &self.space_penalty,
379 );
380 if let Some(node) = lattice.node_mut(node_id) {
381 node.total_cost = best_cost;
382 node.prev_node_id = best_prev;
383 }
384 return;
385 }
386
387 let (left_id, word_cost, has_space) = {
389 let Some(node) = lattice.node(node_id) else {
390 return;
391 };
392 (node.left_id, node.word_cost, node.has_space_before)
393 };
394
395 let space_penalty = if has_space {
397 self.space_penalty.get(left_id)
398 } else {
399 0
400 };
401
402 let mut best_cost = i32::MAX;
403 let mut best_prev = INVALID_NODE_ID;
404
405 for &(prev_id, prev_cost, prev_right_id) in ending_nodes {
406 if prev_cost == i32::MAX {
408 continue;
409 }
410
411 let connection = clamp_oob_cost(conn_cost.cost(prev_right_id, left_id));
413
414 let total = saturating_add_chain(prev_cost, connection, word_cost, space_penalty);
416
417 if total < best_cost {
418 best_cost = total;
419 best_prev = prev_id;
420 }
421 }
422
423 if let Some(node) = lattice.node_mut(node_id) {
425 node.total_cost = best_cost;
426 node.prev_node_id = best_prev;
427 }
428 }
429
430 #[cfg(test)]
432 #[allow(dead_code)]
433 fn update_node_cost<C: ConnectionCost>(
434 &self,
435 lattice: &mut Lattice,
436 conn_cost: &C,
437 node_id: NodeId,
438 pos: usize,
439 ) {
440 let (left_id, word_cost, has_space) = {
442 let Some(node) = lattice.node(node_id) else {
443 return;
444 };
445 (node.left_id, node.word_cost, node.has_space_before)
446 };
447
448 let ending_nodes: Vec<(NodeId, i32, u16)> = lattice
450 .nodes_ending_at(pos)
451 .map(|n| (n.id, n.total_cost, n.right_id))
452 .collect();
453
454 let mut best_cost = i32::MAX;
455 let mut best_prev = INVALID_NODE_ID;
456
457 for (prev_id, prev_cost, prev_right_id) in ending_nodes {
458 if prev_cost == i32::MAX {
459 continue;
460 }
461
462 let connection = clamp_oob_cost(conn_cost.cost(prev_right_id, left_id));
463
464 let space_penalty = if has_space {
465 self.space_penalty.get(left_id)
466 } else {
467 0
468 };
469
470 let total = prev_cost
471 .saturating_add(connection)
472 .saturating_add(word_cost)
473 .saturating_add(space_penalty);
474
475 if total < best_cost {
476 best_cost = total;
477 best_prev = prev_id;
478 }
479 }
480
481 if let Some(node) = lattice.node_mut(node_id) {
483 node.total_cost = best_cost;
484 node.prev_node_id = best_prev;
485 }
486 }
487
488 fn backward_pass(lattice: &Lattice) -> Vec<NodeId> {
492 let mut path = Vec::new();
493 let mut current_id = lattice.eos().id;
494
495 while current_id != INVALID_NODE_ID {
496 if let Some(node) = lattice.node(current_id) {
497 if node.node_type != NodeType::Bos && node.node_type != NodeType::Eos {
499 path.push(current_id);
500 }
501 current_id = node.prev_node_id;
502 } else {
503 break;
504 }
505 }
506
507 path.reverse();
508 path
509 }
510
511 #[must_use]
513 pub fn get_best_cost(&self, lattice: &Lattice) -> i32 {
514 lattice.eos().total_cost
515 }
516
517 #[must_use]
521 pub fn has_valid_path(&self, lattice: &Lattice) -> bool {
522 lattice.eos().total_cost != i32::MAX && lattice.eos().prev_node_id != INVALID_NODE_ID
523 }
524}
525
526#[derive(Debug, Clone)]
535struct PathNode {
536 node_id: NodeId,
538 prev: Option<Rc<Self>>,
540}
541
542impl PathNode {
543 #[allow(clippy::missing_const_for_fn)]
547 fn new(node_id: NodeId, prev: Option<Rc<Self>>) -> Self {
548 Self { node_id, prev }
549 }
550
551 fn to_vec(&self) -> Vec<NodeId> {
553 let mut path = Vec::new();
554 let mut current = Some(self);
555
556 while let Some(node) = current {
557 path.push(node.node_id);
558 current = node.prev.as_ref().map(std::convert::AsRef::as_ref);
559 }
560
561 path.reverse();
562 path
563 }
564}
565
566#[derive(Debug, Clone)]
568struct NbestCandidate {
569 node_id: NodeId,
571 cost: i32,
573 path: Option<Rc<PathNode>>,
575}
576
577impl Eq for NbestCandidate {}
578
579impl PartialEq for NbestCandidate {
580 fn eq(&self, other: &Self) -> bool {
581 self.cost == other.cost
582 }
583}
584
585impl Ord for NbestCandidate {
586 fn cmp(&self, other: &Self) -> Ordering {
587 other.cost.cmp(&self.cost)
589 }
590}
591
592impl PartialOrd for NbestCandidate {
593 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
594 Some(self.cmp(other))
595 }
596}
597
598#[derive(Debug, Clone)]
602pub struct NbestSearcher {
603 viterbi: ViterbiSearcher,
605 max_results: usize,
607}
608
609impl NbestSearcher {
610 #[must_use]
612 pub fn new(n: usize) -> Self {
613 Self {
614 viterbi: ViterbiSearcher::new(),
615 max_results: n,
616 }
617 }
618
619 #[must_use]
621 pub fn with_space_penalty(mut self, penalty: SpacePenalty) -> Self {
622 self.viterbi.space_penalty = penalty;
623 self
624 }
625
626 pub fn search<C: ConnectionCost>(
637 &self,
638 lattice: &mut Lattice,
639 conn_cost: &C,
640 ) -> Vec<(Vec<NodeId>, i32)> {
641 self.viterbi.forward_pass(lattice, conn_cost);
643
644 if !self.viterbi.has_valid_path(lattice) {
646 return Vec::new();
647 }
648
649 if self.max_results == 1 {
651 let path = ViterbiSearcher::backward_pass(lattice);
652 let cost = self.viterbi.get_best_cost(lattice);
653 return vec![(path, cost)];
654 }
655
656 self.search_nbest(lattice, conn_cost)
658 }
659
660 fn search_nbest<C: ConnectionCost>(
667 &self,
668 lattice: &Lattice,
669 _conn_cost: &C,
670 ) -> Vec<(Vec<NodeId>, i32)> {
671 let mut results: Vec<(Vec<NodeId>, i32)> = Vec::new();
672 let mut heap: BinaryHeap<NbestCandidate> = BinaryHeap::new();
673
674 let eos = lattice.eos();
676 if eos.total_cost == i32::MAX {
677 return results;
678 }
679
680 heap.push(NbestCandidate {
681 node_id: eos.id,
682 cost: eos.total_cost,
683 path: None,
684 });
685
686 while let Some(candidate) = heap.pop() {
687 if results.len() >= self.max_results {
688 break;
689 }
690
691 let Some(node) = lattice.node(candidate.node_id) else {
692 continue;
693 };
694
695 let current_path = if node.node_type != NodeType::Bos && node.node_type != NodeType::Eos
697 {
698 Some(Rc::new(PathNode::new(candidate.node_id, candidate.path)))
700 } else {
701 candidate.path
702 };
703
704 if node.node_type == NodeType::Bos {
706 let path_vec = current_path.map_or_else(Vec::new, |path_node| path_node.to_vec());
708 results.push((path_vec, candidate.cost));
709 continue;
710 }
711
712 if node.prev_node_id != INVALID_NODE_ID {
714 heap.push(NbestCandidate {
715 node_id: node.prev_node_id,
716 cost: candidate.cost,
717 path: current_path,
718 });
719 }
720 }
721
722 results
723 }
724}
725
726pub struct ViterbiResult<'a> {
728 lattice: &'a Lattice,
730 path: Vec<NodeId>,
732 total_cost: i32,
734}
735
736impl<'a> ViterbiResult<'a> {
737 #[must_use]
739 pub const fn new(lattice: &'a Lattice, path: Vec<NodeId>, total_cost: i32) -> Self {
740 Self {
741 lattice,
742 path,
743 total_cost,
744 }
745 }
746
747 pub fn nodes(&self) -> impl Iterator<Item = &'a Node> + '_ {
749 self.path.iter().filter_map(|&id| self.lattice.node(id))
750 }
751
752 #[must_use]
754 pub const fn cost(&self) -> i32 {
755 self.total_cost
756 }
757
758 #[must_use]
760 pub fn len(&self) -> usize {
761 self.path.len()
762 }
763
764 #[must_use]
766 pub fn is_empty(&self) -> bool {
767 self.path.is_empty()
768 }
769
770 #[must_use]
772 pub fn surfaces(&self) -> Vec<&str> {
773 self.nodes().map(|n| n.surface.as_ref()).collect()
774 }
775}
776
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::lattice::NodeBuilder;

    /// Sparse connection table for tests: explicit `(right_id, left_id)`
    /// entries, with `default` returned for every other pair.
    struct TestConnectionCost {
        costs: std::collections::HashMap<(u16, u16), i32>,
        default: i32,
    }

    impl TestConnectionCost {
        fn new(default: i32) -> Self {
            Self {
                costs: std::collections::HashMap::new(),
                default,
            }
        }

        /// Overrides the cost of the transition `right_id -> left_id`.
        fn set(&mut self, right_id: u16, left_id: u16, cost: i32) {
            self.costs.insert((right_id, left_id), cost);
        }
    }

    impl ConnectionCost for TestConnectionCost {
        fn cost(&self, right_id: u16, left_id: u16) -> i32 {
            self.costs
                .get(&(right_id, left_id))
                .copied()
                .unwrap_or(self.default)
        }
    }

    #[test]
    fn test_space_penalty_default() {
        let penalty = SpacePenalty::default();
        assert!(penalty.is_empty());
        assert_eq!(penalty.get(100), 0);
    }

    #[test]
    fn test_space_penalty_from_dicrc() {
        let penalty = SpacePenalty::from_dicrc("100,5000;200,3000;300,1000");

        assert_eq!(penalty.len(), 3);
        assert_eq!(penalty.get(100), 5000);
        assert_eq!(penalty.get(200), 3000);
        assert_eq!(penalty.get(300), 1000);
        // Unconfigured ids cost nothing.
        assert_eq!(penalty.get(999), 0);
    }

    #[test]
    fn test_space_penalty_korean_default() {
        let penalty = SpacePenalty::korean_default();
        assert!(!penalty.is_empty());

        assert!(penalty.get(1785) > 0);
    }

    #[test]
    fn test_viterbi_simple_path() {
        // Only one segmentation exists: "A" then "B".
        let mut lattice = Lattice::new("AB");

        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(1)
                .right_id(1)
                .word_cost(100),
        );

        lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(2)
                .right_id(2)
                .word_cost(200),
        );

        let conn_cost = ZeroConnectionCost;
        let searcher = ViterbiSearcher::new();

        let path = searcher.search(&mut lattice, &conn_cost);

        assert_eq!(path.len(), 2);

        let first = lattice.node(path[0]).unwrap();
        assert_eq!(first.surface.as_ref(), "A");

        let second = lattice.node(path[1]).unwrap();
        assert_eq!(second.surface.as_ref(), "B");

        // 100 + 200 with free connections.
        let total_cost = searcher.get_best_cost(&lattice);
        assert_eq!(total_cost, 300);
    }

    #[test]
    fn test_viterbi_choose_best_path() {
        // "A"+"B" costs 300, cheaper than the single "AB" at 500.
        let mut lattice = Lattice::new("AB");

        lattice.add_node(
            NodeBuilder::new("AB", 0, 2)
                .left_id(1)
                .right_id(1)
                .word_cost(500),
        );

        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(2)
                .right_id(2)
                .word_cost(100),
        );

        lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(3)
                .right_id(3)
                .word_cost(200),
        );

        let conn_cost = ZeroConnectionCost;
        let searcher = ViterbiSearcher::new();

        let path = searcher.search(&mut lattice, &conn_cost);

        assert_eq!(path.len(), 2);
        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "A");
        assert_eq!(lattice.node(path[1]).unwrap().surface.as_ref(), "B");
    }

    #[test]
    fn test_viterbi_with_connection_cost() {
        // The A(right 2) -> B(left 3) transition costs 500, so "A"+"B" totals
        // 100 + 500 + 100 = 700 and the single "AB" (300) wins.
        let mut lattice = Lattice::new("AB");

        lattice.add_node(
            NodeBuilder::new("AB", 0, 2)
                .left_id(1)
                .right_id(1)
                .word_cost(300),
        );

        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(2)
                .right_id(2)
                .word_cost(100),
        );

        lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(3)
                .right_id(3)
                .word_cost(100),
        );

        let mut conn_cost = TestConnectionCost::new(0);
        conn_cost.set(2, 3, 500);

        let searcher = ViterbiSearcher::new();
        let path = searcher.search(&mut lattice, &conn_cost);

        assert_eq!(path.len(), 1);
        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "AB");
    }

    #[test]
    fn test_viterbi_with_space_penalty() {
        // "B" is marked as following a space and its left_id (100) carries a
        // 1000 penalty, so "A"+"B" totals 1200 and "AB" (500) wins.
        // NOTE(review): node positions 0..2 assume the lattice does not count
        // the space in "A B" — confirm against `Lattice::new`.
        let mut lattice = Lattice::new("A B");
        lattice.add_node(
            NodeBuilder::new("AB", 0, 2)
                .left_id(1)
                .right_id(1)
                .word_cost(500),
        );

        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(2)
                .right_id(2)
                .word_cost(100),
        );

        lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(100)
                .right_id(3)
                .word_cost(100)
                .has_space_before(true),
        );

        let mut penalty = SpacePenalty::new();
        penalty.add(100, 1000);

        let conn_cost = ZeroConnectionCost;
        let searcher = ViterbiSearcher::new().with_space_penalty(penalty);

        let path = searcher.search(&mut lattice, &conn_cost);

        assert_eq!(path.len(), 1);
        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "AB");
    }

    #[test]
    fn test_viterbi_korean_example() {
        // "아버지"+"가" costs 1500; "아버"+"지가" costs 6000.
        let mut lattice = Lattice::new("아버지가");

        lattice.add_node(
            NodeBuilder::new("아버지", 0, 3)
                .left_id(1)
                .right_id(1)
                .word_cost(1000),
        );
        lattice.add_node(
            NodeBuilder::new("가", 3, 4)
                .left_id(100)
                .right_id(100)
                .word_cost(500),
        );

        lattice.add_node(
            NodeBuilder::new("아버", 0, 2)
                .left_id(2)
                .right_id(2)
                .word_cost(3000),
        );
        lattice.add_node(
            NodeBuilder::new("지가", 2, 4)
                .left_id(3)
                .right_id(3)
                .word_cost(3000),
        );

        let conn_cost = ZeroConnectionCost;
        let searcher = ViterbiSearcher::new();

        let path = searcher.search(&mut lattice, &conn_cost);

        assert_eq!(path.len(), 2);
        assert_eq!(lattice.node(path[0]).unwrap().surface.as_ref(), "아버지");
        assert_eq!(lattice.node(path[1]).unwrap().surface.as_ref(), "가");
    }

    #[test]
    fn test_viterbi_empty_lattice() {
        let mut lattice = Lattice::new("");

        let conn_cost = ZeroConnectionCost;
        let searcher = ViterbiSearcher::new();

        let path = searcher.search(&mut lattice, &conn_cost);

        assert!(path.is_empty());
    }

    #[test]
    fn test_viterbi_no_path() {
        // Only "A" covers position 0..1, so EOS at the end of "ABC" is unreachable.
        let mut lattice = Lattice::new("ABC");

        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(1)
                .right_id(1)
                .word_cost(100),
        );

        let conn_cost = ZeroConnectionCost;
        let searcher = ViterbiSearcher::new();

        let path = searcher.search(&mut lattice, &conn_cost);

        assert!(!searcher.has_valid_path(&lattice));
        assert!(path.is_empty());
    }

    #[test]
    fn test_nbest_single() {
        // n == 1 takes the Viterbi back-pointer shortcut.
        let mut lattice = Lattice::new("AB");

        lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(1)
                .right_id(1)
                .word_cost(100),
        );
        lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(2)
                .right_id(2)
                .word_cost(200),
        );

        let conn_cost = ZeroConnectionCost;
        let searcher = NbestSearcher::new(1);

        let results = searcher.search(&mut lattice, &conn_cost);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, 300);
    }

    #[test]
    fn test_viterbi_result_helper() {
        let mut lattice = Lattice::new("AB");

        let _id1 = lattice.add_node(
            NodeBuilder::new("A", 0, 1)
                .left_id(1)
                .right_id(1)
                .word_cost(100),
        );
        let _id2 = lattice.add_node(
            NodeBuilder::new("B", 1, 2)
                .left_id(2)
                .right_id(2)
                .word_cost(200),
        );

        let conn_cost = ZeroConnectionCost;
        let searcher = ViterbiSearcher::new();
        let path = searcher.search(&mut lattice, &conn_cost);
        let cost = searcher.get_best_cost(&lattice);

        let result = ViterbiResult::new(&lattice, path, cost);

        assert_eq!(result.len(), 2);
        assert_eq!(result.cost(), 300);
        assert_eq!(result.surfaces(), vec!["A", "B"]);
    }

    #[test]
    fn test_viterbi_with_dense_matrix() {
        use mecab_ko_dict::DenseMatrix;

        // 3x3 matrix, default 0. The transitions on the expected path are
        // 0->1 (100), 1->2 (50) and 2->0 (30); 0->2 and 1->0 are expensive or
        // unused alternatives.
        let mut matrix = DenseMatrix::new(3, 3, 0);

        matrix.set(0, 1, 100);
        matrix.set(1, 2, 50);
        matrix.set(2, 0, 30);

        matrix.set(0, 2, 5000);
        matrix.set(1, 0, 200);

        let mut lattice = Lattice::new("책을");

        lattice.add_node(
            NodeBuilder::new("책", 0, 1)
                .left_id(1)
                .right_id(1)
                .word_cost(500),
        );

        lattice.add_node(
            NodeBuilder::new("을", 1, 2)
                .left_id(2)
                .right_id(2)
                .word_cost(100),
        );

        let searcher = ViterbiSearcher::new();
        let path = searcher.search(&mut lattice, &matrix);

        assert!(!path.is_empty());

        let result = ViterbiResult::new(&lattice, path, searcher.get_best_cost(&lattice));
        assert_eq!(result.surfaces(), vec!["책", "을"]);

        // 100 (BOS->책) + 500 + 50 (책->을) + 100 + 30 (을->EOS) = 780,
        // assuming BOS/EOS use context id 0 (NOTE(review): confirm in `Lattice`).
        assert_eq!(result.cost(), 780);
    }
}