1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
|
/**
* Copyright (c) 2016-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <memory>
#include <random>
#include <vector>
#include "matrix.h"
#include "model.h"
#include "real.h"
#include "utils.h"
#include "vector.h"
namespace fasttext {
/**
 * Abstract base class for the output-layer objectives.
 *
 * Subclasses must implement `forward` and `computeOutput`; `predict` is
 * virtual but not pure, so a base implementation exists elsewhere.
 */
class Loss {
 public:
  explicit Loss(std::shared_ptr<Matrix>& wo);
  virtual ~Loss() = default;

  // Presumably computes the loss for targets[targetIndex] and, when
  // `backprop` is true, applies updates scaled by `lr` — TODO confirm in
  // the implementation.
  virtual real forward(
      const std::vector<int32_t>& targets,
      int32_t targetIndex,
      Model::State& state,
      real lr,
      bool backprop) = 0;

  virtual void computeOutput(Model::State& state) const = 0;

  // Fills `heap` with predictions for `state`; `k` and `threshold` are
  // presumably the result count limit and minimum score — verify.
  virtual void predict(
      int32_t k,
      real threshold,
      Predictions& heap,
      Model::State& state) const;

 protected:
  // NOTE(review): presumably precomputed tables backing sigmoid()/log().
  std::vector<real> t_sigmoid_;
  std::vector<real> t_log_;
  // A reference to a shared_ptr, not a copy: the referenced shared_ptr
  // must outlive this object.
  std::shared_ptr<Matrix>& wo_;

  real log(real x) const;
  real sigmoid(real x) const;

 private:
  // Presumably a helper for the base predict(); takes the raw output vector.
  void findKBest(
      int32_t k,
      real threshold,
      Predictions& heap,
      const Vector& output) const;
};
class BinaryLogisticLoss : public Loss {
protected:
real binaryLogistic(
int32_t target,
Model::State& state,
bool labelIsPositive,
real lr,
bool backprop) const;
public:
explicit BinaryLogisticLoss(std::shared_ptr<Matrix>& wo);
virtual ~BinaryLogisticLoss() noexcept override = default;
void computeOutput(Model::State& state) const override;
};
/**
 * One-vs-all objective: a concrete BinaryLogisticLoss that supplies
 * `forward`. `computeOutput` is inherited from BinaryLogisticLoss and
 * `predict` from Loss.
 */
class OneVsAllLoss : public BinaryLogisticLoss {
 public:
  explicit OneVsAllLoss(std::shared_ptr<Matrix>& wo);
  ~OneVsAllLoss() noexcept override = default;

  // Presumably evaluates (and, if `backprop`, updates) one binary term per
  // output label — TODO confirm against the implementation.
  real forward(
      const std::vector<int32_t>& targets,
      int32_t targetIndex,
      Model::State& state,
      real lr,
      bool backprop) override;
};
/**
 * Negative-sampling objective built on BinaryLogisticLoss.
 *
 * Negatives are drawn via `getNegative` from `negatives_` using the
 * `uniform_` index distribution. NOTE(review): `negatives_` is presumably a
 * sampling table (capped by NEGATIVE_TABLE_SIZE) built from `targetCounts`
 * in the constructor — confirm in the implementation.
 */
class NegativeSamplingLoss : public BinaryLogisticLoss {
 public:
  // `neg` is presumably the number of negatives drawn per target — verify.
  explicit NegativeSamplingLoss(
      std::shared_ptr<Matrix>& wo,
      int neg,
      const std::vector<int64_t>& targetCounts);
  ~NegativeSamplingLoss() noexcept override = default;

  real forward(
      const std::vector<int32_t>& targets,
      int32_t targetIndex,
      Model::State& state,
      real lr,
      bool backprop) override;

 protected:
  static const int32_t NEGATIVE_TABLE_SIZE = 10000000;

  int neg_;
  std::vector<int32_t> negatives_;
  std::uniform_int_distribution<size_t> uniform_;

  // Draws one negative for `target` using `rng`; presumably re-draws when
  // the sample equals `target` — TODO confirm.
  int32_t getNegative(int32_t target, std::minstd_rand& rng);
};
/**
 * Hierarchical-softmax objective over a binary tree of output labels.
 *
 * The tree is built from per-label `counts` by `buildTree`; per-label
 * `paths_` and `codes_` presumably record the internal nodes and branch
 * bits from the root to each leaf — TODO confirm in the implementation.
 * Overrides `predict` with a tree search (`dfs`).
 */
class HierarchicalSoftmaxLoss : public BinaryLogisticLoss {
 public:
  explicit HierarchicalSoftmaxLoss(
      std::shared_ptr<Matrix>& wo,
      const std::vector<int64_t>& counts);
  ~HierarchicalSoftmaxLoss() noexcept override = default;

  real forward(
      const std::vector<int32_t>& targets,
      int32_t targetIndex,
      Model::State& state,
      real lr,
      bool backprop) override;

  void predict(
      int32_t k,
      real threshold,
      Predictions& heap,
      Model::State& state) const override;

 protected:
  // One tree node. NOTE(review): parent/left/right are presumably indices
  // into tree_ (with a sentinel for "none"), `count` the aggregated
  // frequency, and `binary` the branch bit — verify in buildTree().
  struct Node {
    int32_t parent;
    int32_t left;
    int32_t right;
    int64_t count;
    bool binary;
  };

  std::vector<std::vector<int32_t>> paths_;
  std::vector<std::vector<bool>> codes_;
  std::vector<Node> tree_;
  int32_t osz_;

  void buildTree(const std::vector<int64_t>& counts);

  // Recursive search from `node` carrying the accumulated `score`; used by
  // predict() to fill `heap`.
  void dfs(
      int32_t k,
      real threshold,
      int32_t node,
      real score,
      Predictions& heap,
      const Vector& hidden) const;
};
/**
 * Full softmax objective. Derives directly from Loss, implements both
 * `forward` and `computeOutput` itself, and inherits Loss's `predict`.
 */
class SoftmaxLoss : public Loss {
 public:
  explicit SoftmaxLoss(std::shared_ptr<Matrix>& wo);
  ~SoftmaxLoss() noexcept override = default;

  real forward(
      const std::vector<int32_t>& targets,
      int32_t targetIndex,
      Model::State& state,
      real lr,
      bool backprop) override;

  void computeOutput(Model::State& state) const override;
};
} // namespace fasttext
|