1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
|
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <memory>
#include <random>
#include <utility>
#include <vector>
#include "matrix.h"
#include "real.h"
#include "utils.h"
#include "vector.h"
namespace fasttext {

class Loss;

// Model bundles two shared parameter matrices (wi_, wo_), the Loss object
// used with them, and a gradient-normalization flag.
//
// NOTE(review): wi_/wo_ are presumably the input and output matrices
// (inferred from the names only) — confirm against model.cc.
//
// Model is neither copyable nor movable: all four copy/move special members
// are explicitly deleted below. Hold it through a pointer to share it.
class Model {
 protected:
  std::shared_ptr<Matrix> wi_;
  std::shared_ptr<Matrix> wo_;
  std::shared_ptr<Loss> loss_;
  // Flag forwarded from the constructor. NOTE(review): presumably controls
  // whether update() rescales the gradient — confirm in model.cc.
  bool normalizeGradient_;

 public:
  // Stores the given matrices, loss and normalization flag. The shared_ptr
  // parameters mean the Model co-owns them with the caller.
  Model(
      std::shared_ptr<Matrix> wi,
      std::shared_ptr<Matrix> wo,
      std::shared_ptr<Loss> loss,
      bool normalizeGradient);
  Model(const Model& model) = delete;
  Model(Model&& model) = delete;
  Model& operator=(const Model& other) = delete;
  Model& operator=(Model&& other) = delete;

  // Mutable working state that callers pass into predict(), update() and
  // computeHidden(). Because predict() and computeHidden() are const, any
  // per-call scratch data they produce has to live here rather than in Model.
  // NOTE(review): presumably one State per worker/thread — confirm usage.
  class State {
   private:
    // Default member initializers guarantee the accumulators start at zero
    // even if a constructor path forgets to set them. Any mem-initializer
    // in the constructor definition still takes precedence.
    real lossValue_ = 0.0;
    int64_t nexamples_ = 0;

   public:
    // Scratch vectors. NOTE(review): presumably sized from the constructor's
    // hiddenSize/outputSize arguments — confirm in model.cc.
    Vector hidden;
    Vector output;
    Vector grad;
    // Per-state random engine; seeded from the constructor's `seed` argument
    // (assumed — confirm in model.cc).
    std::minstd_rand rng;

    State(int32_t hiddenSize, int32_t outputSize, int32_t seed);
    // Reads the accumulated loss. NOTE(review): presumably an average over
    // nexamples_ — confirm in model.cc.
    real getLoss() const;
    // Accumulates one example's loss into this state.
    void incrementNExamples(real loss);
  };

  // Scores `input` and writes results into `heap`. Parameters `k` and
  // `threshold` are forwarded limits. The call does not modify the Model
  // itself (const).
  void predict(
      const std::vector<int32_t>& input,
      int32_t k,
      real threshold,
      Predictions& heap,
      State& state) const;
  // Training step for `input` against `targets` with learning rate `lr`.
  // This is the only non-const operation on the parameters declared here.
  void update(
      const std::vector<int32_t>& input,
      const std::vector<int32_t>& targets,
      int32_t targetIndex,
      real lr,
      State& state);
  // Fills `state` with the hidden representation of `input` (writes go to
  // `state`; the Model is not modified).
  void computeHidden(const std::vector<int32_t>& input, State& state) const;

  // Logarithm helper. NOTE(review): implementation not visible here — confirm
  // whether it guards against log(0).
  real std_log(real) const;

  // Sentinel (-1) for predict()'s `k`; by its name, requests no limit on the
  // number of predictions — confirm in model.cc.
  static const int32_t kUnlimitedPredictions = -1;
  // Sentinel (-1) for update()'s `targetIndex`; by its name, treats every
  // entry of `targets` as a target — confirm in model.cc.
  static const int32_t kAllLabelsAsTarget = -1;
};

} // namespace fasttext
|