Tutorials Code: Source Code for Tutorial 7

From DSL
Jump to: navigation, search
  #include <iostream>
  #include <sstream>
  #include <fstream>
  
  #include "smile.h"
  #include "smilearn.h"
  
  using namespace std;
  
  
  void learnEMDynamic() {
     // open the data set:
     DSL_dataset ds;
     if (ds.ReadFile("ds_tut_6.txt") != DSL_OKAY) {
        cout << "Cannot read data file... exiting." << endl;
        exit(1);
     }
     
     // open the network:
     DSL_network net;
     if (net.ReadFile("net_tut_6.xdsl", DSL_XDSL_FORMAT) != DSL_OKAY) {
        cout << "Cannot read network... exiting." << endl;
        exit(1);
     }
     
     // match the data set and the network (variables):
     vector<DSL_datasetMatch> dsMap(ds.GetNumberOfVariables());
     int varCnt = 0;  // the number of variables occuring both in the data set and the network
     for (int i = 0; i < ds.GetNumberOfVariables(); i++) {
        string id = ds.GetId(i);
        const char* idc = id.c_str();
        
        bool done = false;
        for (int j = 0; j < (int) strlen(idc) && !done; j++) {
           if (idc[j] == '_') {
              char* nodeId = (char*) malloc((j+1) * sizeof(char));
              strncpy(nodeId, idc, j);
              nodeId[j] = '\0';
              
              int nodeHdl = net.FindNode(nodeId);
              if (nodeHdl >= 0) {
                 DSL_intArray orders;
                 net.GetTemporalOrders(nodeHdl, orders);
                 
                 dsMap[varCnt].node   = nodeHdl;
                 dsMap[varCnt].slice  = atoi(idc + j + 1);
                 dsMap[varCnt].column = i;
                 varCnt++;
                 
                 free(nodeId);
                 done = true;
              }
           }
        }
        if (!done) {
           int nodeHdl = net.FindNode(idc);
           if (nodeHdl >= 0) {
              dsMap[varCnt].node   = nodeHdl;
              dsMap[varCnt].slice  = 0;
              dsMap[varCnt].column = i;
              varCnt++;
           }
        }
     }
     dsMap.resize(varCnt);
     
     // match the data set and the network (states):
     for (int i = 0; i < dsMap.size(); i++) {
        DSL_datasetMatch &m = dsMap[i];
        int nodeHdl = m.node;
        int colIdx = m.column;
        
        DSL_idArray* ids = net.GetNode(nodeHdl)->Definition()->GetOutcomesNames();
        const DSL_datasetVarInfo &varInfo = ds.GetVariableInfo(colIdx);
        const vector<string> &stateNames = varInfo.stateNames;
        vector<int> map(stateNames.size(), -1);
        for (int j = 0; j < (int) stateNames.size(); j++) {
           const char* id = stateNames[j].c_str();
           for (int k = 0; k < ids->NumItems(); k++) {
              char* tmpid = (*ids)[k];
              if (!strcmp(id, tmpid)) {
                 map[j] = k;
              }
           }
        }
        for (int k = 0; k < ds.GetNumberOfRecords(); k++) {
           if (ds.GetInt(colIdx, k) >= 0) {
              ds.SetInt(colIdx, k, map[ds.GetInt(colIdx, k)]);
           }
        }
     }
     
     // learn parameters:
     DSL_em em;
     if (em.Learn(ds, net, dsMap) != DSL_OKAY) {
        cout << "Cannot learn parameters... exiting." << endl;
        exit(1);
     }
     net.WriteFile("res_tut_7.xdsl", DSL_XDSL_FORMAT);
  }
  
  
  int main(int argc, char* const argv[]) {
    learnEMDynamic();
  }
  
  
  /*
  ds_tut_6.txt
  
  A A_0 A_1 B B_0 B_1 C C_0 C_1
  t t t t t t t t t
  t f f t f f t f f
  f t f f t f f t f
  f f f f f f f f f
  t f t t f t t f t
  t f f t f f t f f
  f t f f t f f t f
  f t t f t t f t t
  t t t t t t t t t
  t t t t t t t t t
  t t t t t t t t t
  t t f t t f t t f
  f f f f f f f f f
  f f t f f t f f t
  f f f f f f f f f
  t f f t f f t f f
  f t f f t f f t f
  t t t t t t t t t
  t f f t f f t f f
  f t f f t f f t f
  */
Personal tools