// Tutorial code: source code for Tutorial 7
// From DSL
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include "smile.h"
#include "smilearn.h"
using namespace std;
void learnEMDynamic() {
// open the data set:
DSL_dataset ds;
if (ds.ReadFile("ds_tut_6.txt") != DSL_OKAY) {
cout << "Cannot read data file... exiting." << endl;
exit(1);
}
// open the network:
DSL_network net;
if (net.ReadFile("net_tut_6.xdsl", DSL_XDSL_FORMAT) != DSL_OKAY) {
cout << "Cannot read network... exiting." << endl;
exit(1);
}
// match the data set and the network (variables):
vector<DSL_datasetMatch> dsMap(ds.GetNumberOfVariables());
int varCnt = 0; // the number of variables occuring both in the data set and the network
for (int i = 0; i < ds.GetNumberOfVariables(); i++) {
string id = ds.GetId(i);
const char* idc = id.c_str();
bool done = false;
for (int j = 0; j < (int) strlen(idc) && !done; j++) {
if (idc[j] == '_') {
char* nodeId = (char*) malloc((j+1) * sizeof(char));
strncpy(nodeId, idc, j);
nodeId[j] = '\0';
int nodeHdl = net.FindNode(nodeId);
if (nodeHdl >= 0) {
DSL_intArray orders;
net.GetTemporalOrders(nodeHdl, orders);
dsMap[varCnt].node = nodeHdl;
dsMap[varCnt].slice = atoi(idc + j + 1);
dsMap[varCnt].column = i;
varCnt++;
free(nodeId);
done = true;
}
}
}
if (!done) {
int nodeHdl = net.FindNode(idc);
if (nodeHdl >= 0) {
dsMap[varCnt].node = nodeHdl;
dsMap[varCnt].slice = 0;
dsMap[varCnt].column = i;
varCnt++;
}
}
}
dsMap.resize(varCnt);
// match the data set and the network (states):
for (int i = 0; i < dsMap.size(); i++) {
DSL_datasetMatch &m = dsMap[i];
int nodeHdl = m.node;
int colIdx = m.column;
DSL_idArray* ids = net.GetNode(nodeHdl)->Definition()->GetOutcomesNames();
const DSL_datasetVarInfo &varInfo = ds.GetVariableInfo(colIdx);
const vector<string> &stateNames = varInfo.stateNames;
vector<int> map(stateNames.size(), -1);
for (int j = 0; j < (int) stateNames.size(); j++) {
const char* id = stateNames[j].c_str();
for (int k = 0; k < ids->NumItems(); k++) {
char* tmpid = (*ids)[k];
if (!strcmp(id, tmpid)) {
map[j] = k;
}
}
}
for (int k = 0; k < ds.GetNumberOfRecords(); k++) {
if (ds.GetInt(colIdx, k) >= 0) {
ds.SetInt(colIdx, k, map[ds.GetInt(colIdx, k)]);
}
}
}
// learn parameters:
DSL_em em;
if (em.Learn(ds, net, dsMap) != DSL_OKAY) {
cout << "Cannot learn parameters... exiting." << endl;
exit(1);
}
net.WriteFile("res_tut_7.xdsl", DSL_XDSL_FORMAT);
}
// Program entry point: runs the Tutorial 7 EM-learning example.
// Command-line arguments are accepted but not used.
int main(int argc, char* const argv[]) {
    (void) argc;
    (void) argv;
    learnEMDynamic();
    return 0;
}
/*
Example contents of ds_tut_6.txt, the data set read by learnEMDynamic():
A A_0 A_1 B B_0 B_1 C C_0 C_1
t t t t t t t t t
t f f t f f t f f
f t f f t f f t f
f f f f f f f f f
t f t t f t t f t
t f f t f f t f f
f t f f t f f t f
f t t f t t f t t
t t t t t t t t t
t t t t t t t t t
t t t t t t t t t
t t f t t f t t f
f f f f f f f f f
f f t f f t f f t
f f f f f f f f f
t f f t f f t f f
f t f f t f f t f
t t t t t t t t t
t f f t f f t f f
f t f f t f f t f
*/