/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "model.h"

#include "loss.h"
#include "utils.h"

#include <algorithm>
#include <stdexcept>

namespace fasttext {

Model::State::State(int32_t hiddenSize, int32_t outputSize, int32_t seed)
    : lossValue_(0.0),
      nexamples_(0),
      hidden(hiddenSize),
      output(outputSize),
      grad(hiddenSize),
      rng(seed) {}

// Average loss over all examples processed so far.
real Model::State::getLoss() const {
  return lossValue_ / nexamples_;
}

void Model::State::incrementNExamples(real loss) {
  lossValue_ += loss;
  nexamples_++;
}

Model::Model(
    std::shared_ptr<Matrix> wi,
    std::shared_ptr<Matrix> wo,
    std::shared_ptr<Loss> loss,
    bool normalizeGradient)
    : wi_(wi), wo_(wo), loss_(loss), normalizeGradient_(normalizeGradient) {}

// The hidden vector is the average of the input rows of the embedding
// matrix wi_.
void Model::computeHidden(const std::vector<int32_t>& input, State& state)
    const {
  Vector& hidden = state.hidden;
  hidden.zero();
  for (auto it = input.cbegin(); it != input.cend(); ++it) {
    hidden.addRow(*wi_, *it);
  }
  hidden.mul(1.0 / input.size());
}

void Model::predict(
    const std::vector<int32_t>& input,
    int32_t k,
    real threshold,
    Predictions& heap,
    State& state) const {
  if (k == Model::kUnlimitedPredictions) {
    k = wo_->size(0); // output size
  } else if (k <= 0) {
    throw std::invalid_argument("k needs to be 1 or higher!");
  }
  heap.reserve(k + 1);
  computeHidden(input, state);
  loss_->predict(k, threshold, heap, state);
}

void Model::update(
    const std::vector<int32_t>& input,
    const std::vector<int32_t>& targets,
    int32_t targetIndex,
    real lr,
    State& state) {
  if (input.size() == 0) {
    return;
  }
  computeHidden(input, state);
  Vector& grad = state.grad;
  grad.zero();
  // Forward pass through the loss; the final `true` enables backpropagation,
  // which accumulates the hidden-layer gradient into state.grad and
  // updates the output matrix wo_.
  real lossValue = loss_->forward(targets, targetIndex, state, lr, true);
  state.incrementNExamples(lossValue);
  if (normalizeGradient_) {
    grad.mul(1.0 / input.size());
  }
  // Propagate the hidden-layer gradient back to each input row of wi_.
  for (auto it = input.cbegin(); it != input.cend(); ++it) {
    wi_->addVectorToRow(grad, *it, 1.0);
  }
}

real Model::std_log(real x) const {
  // Add a small epsilon to avoid taking the log of zero.
  return std::log(x + 1e-5);
}

} // namespace fasttext