StochHMM  v0.34
Flexible Hidden Markov Model C++ Library and Application
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
baum_welch.cpp
Go to the documentation of this file.
1 //
2 // baum_welch.cpp
3 // StochHMM
4 //
5 // Created by Paul Lott on 2/4/13.
6 // Copyright (c) 2013 Korf Lab, Genome Center, UC Davis, Davis, CA. All rights reserved.
7 //
8 
9 #include "trellis.h"
10 
11 namespace StochHMM {
12 
13  //TODO: Need to implement functions to allow Baum-Welch to update the model.
14 
15 
17 
18  }
19 
20 
22  if (dbl_forward_score == NULL){
23  naive_forward();
24  }
25 
26  if (dbl_backward_score == NULL){
28  }
29 
30  dbl_baum_welch_score = new (std::nothrow) double_3D(seq_size, std::vector<std::vector<double> >(state_size, std::vector<double>(state_size, -INFINITY)));
31 
32  if (dbl_baum_welch_score == NULL){
33  std::cerr << "Can't allocate memory. OUT OF MEMORY\t" << __FUNCTION__ << std::endl;
34  exit(2);
35  }
36 
37 
38  double sum(-INFINITY);
39 
40  for(size_t position = 0; position < seq_size-1; position++){ // Time(t)
41  sum = (-INFINITY);
42  for (size_t previous = 0; previous < state_size ; previous++){ // state(i)
43  for (size_t current = 0; current < state_size; current++){ // state(j)
44  (*dbl_baum_welch_score)[position][previous][current] = (*dbl_forward_score)[position][previous] + getTransition(hmm->getState(previous), current, position) + (*hmm)[current]->get_emission_prob(*seqs, position+1) + (*dbl_backward_score)[position+1][current];
45  sum = addLog((*dbl_baum_welch_score)[position][previous][current], sum);
46  }
47  }
48  for (size_t previous = 0; previous < state_size ; previous++){ // state(i)
49  for (size_t current = 0; current < state_size; current++){ // state(j)
50  (*dbl_baum_welch_score)[position][previous][current] -= sum;
51  }
52  }
53  }
54  return;
55  }
56 
57 
58 
60 
61  std::cout << "Transitions to Start:\n";
62  double updated(-INFINITY);
63  for(size_t st = 0 ; st < state_size ; st++){
64  updated = ((*dbl_backward_score)[0][st] + (*dbl_forward_score)[0][st])-ending_forward_prob;
65  std::cout << hmm->getStateName(st) << "\t" << exp(updated) << std::endl;
66  }
67 
68  float_2D numerator(state_size, std::vector<float>(state_size,-INFINITY));
69  float_2D denominator(state_size, std::vector<float>(state_size,-INFINITY));
70 
71  for (size_t position = 0; position < seq_size-1; position++){
72  for (size_t i = 0; i < state_size; i++){
73  for (size_t j = 0; j < state_size; j++){
74  numerator[i][j]= addLog((double)numerator[i][j], (*dbl_baum_welch_score)[position][i][j]);
75  denominator[i][j] = addLog((double)denominator[i][j], ((*dbl_forward_score)[position][i] + (*dbl_backward_score)[position][i])-ending_forward_prob);
76  }
77  }
78  }
79 
80  for (size_t i = 0; i < state_size; i++){
81  std::cout << hmm->getStateName(i) << "\t";
82  for (size_t j = 0; j < state_size; j++){
83  std::cout << exp((numerator[i][j]-denominator[i][j])) << "\t";
84  }
85  std::cout << std::endl;
86  }
87  }
88 
90 
91  }
92 
93 
94 
95 }