forked from moreo/QuaPy
documenting quanet
This commit is contained in:
parent
9cf9c73824
commit
164f7d8d5c
|
@ -11,6 +11,10 @@ used for evaluating quantification methods.
|
||||||
QuaPy also integrates commonly used datasets and offers visualization tools
|
QuaPy also integrates commonly used datasets and offers visualization tools
|
||||||
for facilitating the analysis and interpretation of results.
|
for facilitating the analysis and interpretation of results.
|
||||||
|
|
||||||
|
### Last updates:
|
||||||
|
|
||||||
|
* A detailed developer API documentation is now available [here]()!
|
||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
```commandline
|
```commandline
|
||||||
|
|
|
@ -292,8 +292,6 @@
|
||||||
<li><a href="quapy.method.html#quapy.method.meta.ensembleFactory">ensembleFactory() (in module quapy.method.meta)</a>
|
<li><a href="quapy.method.html#quapy.method.meta.ensembleFactory">ensembleFactory() (in module quapy.method.meta)</a>
|
||||||
</li>
|
</li>
|
||||||
<li><a href="quapy.method.html#quapy.method.meta.EPACC">EPACC() (in module quapy.method.meta)</a>
|
<li><a href="quapy.method.html#quapy.method.meta.EPACC">EPACC() (in module quapy.method.meta)</a>
|
||||||
</li>
|
|
||||||
<li><a href="quapy.method.html#quapy.method.neural.QuaNetTrainer.epoch">epoch() (quapy.method.neural.QuaNetTrainer method)</a>
|
|
||||||
</li>
|
</li>
|
||||||
<li><a href="quapy.method.html#quapy.method.aggregative.EMQ.EPSILON">EPSILON (quapy.method.aggregative.EMQ attribute)</a>
|
<li><a href="quapy.method.html#quapy.method.aggregative.EMQ.EPSILON">EPSILON (quapy.method.aggregative.EMQ attribute)</a>
|
||||||
</li>
|
</li>
|
||||||
|
@ -390,8 +388,6 @@
|
||||||
<li><a href="quapy.html#quapy.evaluation.gen_prevalence_prediction">gen_prevalence_prediction() (in module quapy.evaluation)</a>
|
<li><a href="quapy.html#quapy.evaluation.gen_prevalence_prediction">gen_prevalence_prediction() (in module quapy.evaluation)</a>
|
||||||
</li>
|
</li>
|
||||||
<li><a href="quapy.html#quapy.evaluation.gen_prevalence_report">gen_prevalence_report() (in module quapy.evaluation)</a>
|
<li><a href="quapy.html#quapy.evaluation.gen_prevalence_report">gen_prevalence_report() (in module quapy.evaluation)</a>
|
||||||
</li>
|
|
||||||
<li><a href="quapy.method.html#quapy.method.neural.QuaNetTrainer.get_aggregative_estims">get_aggregative_estims() (quapy.method.neural.QuaNetTrainer method)</a>
|
|
||||||
</li>
|
</li>
|
||||||
<li><a href="quapy.html#quapy.functional.get_nprevpoints_approximation">get_nprevpoints_approximation() (in module quapy.functional)</a>
|
<li><a href="quapy.html#quapy.functional.get_nprevpoints_approximation">get_nprevpoints_approximation() (in module quapy.functional)</a>
|
||||||
</li>
|
</li>
|
||||||
|
@ -452,8 +448,6 @@
|
||||||
<li><a href="quapy.data.html#quapy.data.preprocessing.index">index() (in module quapy.data.preprocessing)</a>
|
<li><a href="quapy.data.html#quapy.data.preprocessing.index">index() (in module quapy.data.preprocessing)</a>
|
||||||
</li>
|
</li>
|
||||||
<li><a href="quapy.data.html#quapy.data.preprocessing.IndexTransformer">IndexTransformer (class in quapy.data.preprocessing)</a>
|
<li><a href="quapy.data.html#quapy.data.preprocessing.IndexTransformer">IndexTransformer (class in quapy.data.preprocessing)</a>
|
||||||
</li>
|
|
||||||
<li><a href="quapy.method.html#quapy.method.neural.QuaNetModule.init_hidden">init_hidden() (quapy.method.neural.QuaNetModule method)</a>
|
|
||||||
</li>
|
</li>
|
||||||
<li><a href="quapy.method.html#quapy.method.base.isaggregative">isaggregative() (in module quapy.method.base)</a>
|
<li><a href="quapy.method.html#quapy.method.base.isaggregative">isaggregative() (in module quapy.method.base)</a>
|
||||||
</li>
|
</li>
|
||||||
|
|
Binary file not shown.
|
@ -1756,6 +1756,24 @@ in terms of this error.</p>
|
||||||
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetModule">
|
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetModule">
|
||||||
<em class="property"><span class="pre">class</span> </em><span class="sig-prename descclassname"><span class="pre">quapy.method.neural.</span></span><span class="sig-name descname"><span class="pre">QuaNetModule</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">doc_embedding_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_classes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">stats_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">lstm_hidden_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">64</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">lstm_nlayers</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">ff_layers</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">[1024,</span> <span class="pre">512]</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bidirectional</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">qdrop_p</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.5</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">order_by</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#quapy.method.neural.QuaNetModule" title="Permalink to this definition">¶</a></dt>
|
<em class="property"><span class="pre">class</span> </em><span class="sig-prename descclassname"><span class="pre">quapy.method.neural.</span></span><span class="sig-name descname"><span class="pre">QuaNetModule</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">doc_embedding_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_classes</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">stats_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">lstm_hidden_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">64</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">lstm_nlayers</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">ff_layers</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">[1024,</span> <span class="pre">512]</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bidirectional</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">qdrop_p</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.5</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">order_by</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#quapy.method.neural.QuaNetModule" title="Permalink to this definition">¶</a></dt>
|
||||||
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
|
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.modules.module.Module</span></code></p>
|
||||||
|
<p>Implements the <a class="reference external" href="https://dl.acm.org/doi/abs/10.1145/3269206.3269287">QuaNet</a> forward pass.
|
||||||
|
See <a class="reference internal" href="#quapy.method.neural.QuaNetTrainer" title="quapy.method.neural.QuaNetTrainer"><code class="xref py py-class docutils literal notranslate"><span class="pre">QuaNetTrainer</span></code></a> for training QuaNet.</p>
|
||||||
|
<dl class="field-list simple">
|
||||||
|
<dt class="field-odd">Parameters</dt>
|
||||||
|
<dd class="field-odd"><ul class="simple">
|
||||||
|
<li><p><strong>doc_embedding_size</strong> – integer, the dimensionality of the document embeddings</p></li>
|
||||||
|
<li><p><strong>n_classes</strong> – integer, number of classes</p></li>
|
||||||
|
<li><p><strong>stats_size</strong> – integer, number of statistics estimated by simple quantification methods</p></li>
|
||||||
|
<li><p><strong>lstm_hidden_size</strong> – integer, hidden dimensionality of the LSTM cell</p></li>
|
||||||
|
<li><p><strong>lstm_nlayers</strong> – integer, number of LSTM layers</p></li>
|
||||||
|
<li><p><strong>ff_layers</strong> – list of integers, dimensions of the densely-connected FF layers on top of the
|
||||||
|
quantification embedding</p></li>
|
||||||
|
<li><p><strong>bidirectional</strong> – boolean, whether or not to use bidirectional LSTM</p></li>
|
||||||
|
<li><p><strong>qdrop_p</strong> – float, dropout probability</p></li>
|
||||||
|
<li><p><strong>order_by</strong> – integer, class for which the document embeddings are to be sorted</p></li>
|
||||||
|
</ul>
|
||||||
|
</dd>
|
||||||
|
</dl>
|
||||||
<dl class="py property">
|
<dl class="py property">
|
||||||
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetModule.device">
|
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetModule.device">
|
||||||
<em class="property"><span class="pre">property</span> </em><span class="sig-name descname"><span class="pre">device</span></span><a class="headerlink" href="#quapy.method.neural.QuaNetModule.device" title="Permalink to this definition">¶</a></dt>
|
<em class="property"><span class="pre">property</span> </em><span class="sig-name descname"><span class="pre">device</span></span><a class="headerlink" href="#quapy.method.neural.QuaNetModule.device" title="Permalink to this definition">¶</a></dt>
|
||||||
|
@ -1775,17 +1793,62 @@ registered hooks while the latter silently ignores them.</p>
|
||||||
</div>
|
</div>
|
||||||
</dd></dl>
|
</dd></dl>
|
||||||
|
|
||||||
<dl class="py method">
|
|
||||||
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetModule.init_hidden">
|
|
||||||
<span class="sig-name descname"><span class="pre">init_hidden</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#quapy.method.neural.QuaNetModule.init_hidden" title="Permalink to this definition">¶</a></dt>
|
|
||||||
<dd></dd></dl>
|
|
||||||
|
|
||||||
</dd></dl>
|
</dd></dl>
|
||||||
|
|
||||||
<dl class="py class">
|
<dl class="py class">
|
||||||
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetTrainer">
|
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetTrainer">
|
||||||
<em class="property"><span class="pre">class</span> </em><span class="sig-prename descclassname"><span class="pre">quapy.method.neural.</span></span><span class="sig-name descname"><span class="pre">QuaNetTrainer</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">learner</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">sample_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_epochs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">100</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">tr_iter_per_poch</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">500</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">va_iter_per_poch</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">100</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">lr</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">lstm_hidden_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">64</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">lstm_nlayers</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">ff_layers</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">[1024,</span> <span class="pre">512]</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bidirectional</span></span><span class="o"><span 
class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">qdrop_p</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.5</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">patience</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">10</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">checkpointdir</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'../checkpoint'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">checkpointname</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">device</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'cuda'</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#quapy.method.neural.QuaNetTrainer" title="Permalink to this definition">¶</a></dt>
|
<em class="property"><span class="pre">class</span> </em><span class="sig-prename descclassname"><span class="pre">quapy.method.neural.</span></span><span class="sig-name descname"><span class="pre">QuaNetTrainer</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">learner</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">sample_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_epochs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">100</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">tr_iter_per_poch</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">500</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">va_iter_per_poch</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">100</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">lr</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">lstm_hidden_size</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">64</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">lstm_nlayers</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">ff_layers</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">[1024,</span> <span class="pre">512]</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bidirectional</span></span><span class="o"><span 
class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">qdrop_p</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.5</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">patience</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">10</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">checkpointdir</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'../checkpoint'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">checkpointname</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">device</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'cuda'</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#quapy.method.neural.QuaNetTrainer" title="Permalink to this definition">¶</a></dt>
|
||||||
<dd><p>Bases: <a class="reference internal" href="#quapy.method.base.BaseQuantifier" title="quapy.method.base.BaseQuantifier"><code class="xref py py-class docutils literal notranslate"><span class="pre">quapy.method.base.BaseQuantifier</span></code></a></p>
|
<dd><p>Bases: <a class="reference internal" href="#quapy.method.base.BaseQuantifier" title="quapy.method.base.BaseQuantifier"><code class="xref py py-class docutils literal notranslate"><span class="pre">quapy.method.base.BaseQuantifier</span></code></a></p>
|
||||||
|
<p>Implementation of <a class="reference external" href="https://dl.acm.org/doi/abs/10.1145/3269206.3269287">QuaNet</a>, a neural network for
|
||||||
|
quantification. This implementation uses <a class="reference external" href="https://pytorch.org/">PyTorch</a> and can take advantage of GPU
|
||||||
|
for speeding up the training phase.</p>
|
||||||
|
<p>Example:</p>
|
||||||
|
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">quapy</span> <span class="k">as</span> <span class="nn">qp</span>
|
||||||
|
<span class="gp">>>> </span><span class="kn">from</span> <span class="nn">quapy.method.meta</span> <span class="kn">import</span> <span class="n">QuaNet</span>
|
||||||
|
<span class="gp">>>> </span><span class="kn">from</span> <span class="nn">quapy.classification.neural</span> <span class="kn">import</span> <span class="n">NeuralClassifierTrainer</span><span class="p">,</span> <span class="n">CNNnet</span>
|
||||||
|
<span class="go">>>></span>
|
||||||
|
<span class="gp">>>> </span><span class="c1"># use samples of 100 elements</span>
|
||||||
|
<span class="gp">>>> </span><span class="n">qp</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'SAMPLE_SIZE'</span><span class="p">]</span> <span class="o">=</span> <span class="mi">100</span>
|
||||||
|
<span class="go">>>></span>
|
||||||
|
<span class="gp">>>> </span><span class="c1"># load the kindle dataset as text, and convert words to numerical indexes</span>
|
||||||
|
<span class="gp">>>> </span><span class="n">dataset</span> <span class="o">=</span> <span class="n">qp</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">fetch_reviews</span><span class="p">(</span><span class="s1">'kindle'</span><span class="p">,</span> <span class="n">pickle</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||||
|
<span class="gp">>>> </span><span class="n">qp</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">preprocessing</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">min_df</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||||||
|
<span class="go">>>></span>
|
||||||
|
<span class="gp">>>> </span><span class="c1"># the text classifier is a CNN trained by NeuralClassifierTrainer</span>
|
||||||
|
<span class="gp">>>> </span><span class="n">cnn</span> <span class="o">=</span> <span class="n">CNNnet</span><span class="p">(</span><span class="n">dataset</span><span class="o">.</span><span class="n">vocabulary_size</span><span class="p">,</span> <span class="n">dataset</span><span class="o">.</span><span class="n">n_classes</span><span class="p">)</span>
|
||||||
|
<span class="gp">>>> </span><span class="n">learner</span> <span class="o">=</span> <span class="n">NeuralClassifierTrainer</span><span class="p">(</span><span class="n">cnn</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||||
|
<span class="go">>>></span>
|
||||||
|
<span class="gp">>>> </span><span class="c1"># train QuaNet (QuaNet is an alias to QuaNetTrainer)</span>
|
||||||
|
<span class="gp">>>> </span><span class="n">model</span> <span class="o">=</span> <span class="n">QuaNet</span><span class="p">(</span><span class="n">learner</span><span class="p">,</span> <span class="n">qp</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'SAMPLE_SIZE'</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
||||||
|
<span class="gp">>>> </span><span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">dataset</span><span class="o">.</span><span class="n">training</span><span class="p">)</span>
|
||||||
|
<span class="gp">>>> </span><span class="n">estim_prevalence</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">quantify</span><span class="p">(</span><span class="n">dataset</span><span class="o">.</span><span class="n">test</span><span class="o">.</span><span class="n">instances</span><span class="p">)</span>
|
||||||
|
</pre></div>
|
||||||
|
</div>
|
||||||
|
<dl class="field-list simple">
|
||||||
|
<dt class="field-odd">Parameters</dt>
|
||||||
|
<dd class="field-odd"><ul class="simple">
|
||||||
|
<li><p><strong>learner</strong> – an object implementing <cite>fit</cite> (i.e., that can be trained on labelled data),
|
||||||
|
<cite>predict_proba</cite> (i.e., that can generate posterior probabilities of unlabelled examples) and
|
||||||
|
<cite>transform</cite> (i.e., that can generate embedded representations of the unlabelled instances).</p></li>
|
||||||
|
<li><p><strong>sample_size</strong> – integer, the sample size</p></li>
|
||||||
|
<li><p><strong>n_epochs</strong> – integer, maximum number of training epochs</p></li>
|
||||||
|
<li><p><strong>tr_iter_per_poch</strong> – integer, number of training iterations before considering an epoch complete</p></li>
|
||||||
|
<li><p><strong>va_iter_per_poch</strong> – integer, number of validation iterations to perform after each epoch</p></li>
|
||||||
|
<li><p><strong>lr</strong> – float, the learning rate</p></li>
|
||||||
|
<li><p><strong>lstm_hidden_size</strong> – integer, hidden dimensionality of the LSTM cells</p></li>
|
||||||
|
<li><p><strong>lstm_nlayers</strong> – integer, number of LSTM layers</p></li>
|
||||||
|
<li><p><strong>ff_layers</strong> – list of integers, dimensions of the densely-connected FF layers on top of the
|
||||||
|
quantification embedding</p></li>
|
||||||
|
<li><p><strong>bidirectional</strong> – boolean, indicates whether the LSTM is bidirectional or not</p></li>
|
||||||
|
<li><p><strong>qdrop_p</strong> – float, dropout probability</p></li>
|
||||||
|
<li><p><strong>patience</strong> – integer, number of epochs showing no improvement in the validation set before stopping the
|
||||||
|
training phase (early stopping)</p></li>
|
||||||
|
<li><p><strong>checkpointdir</strong> – string, the path of the directory in which to store the model’s checkpoints</p></li>
|
||||||
|
<li><p><strong>checkpointname</strong> – string (optional), the name of the model’s checkpoint</p></li>
|
||||||
|
<li><p><strong>device</strong> – string, either “cpu” or “cuda”</p></li>
|
||||||
|
</ul>
|
||||||
|
</dd>
|
||||||
|
</dl>
|
||||||
<dl class="py property">
|
<dl class="py property">
|
||||||
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetTrainer.classes_">
|
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetTrainer.classes_">
|
||||||
<em class="property"><span class="pre">property</span> </em><span class="sig-name descname"><span class="pre">classes_</span></span><a class="headerlink" href="#quapy.method.neural.QuaNetTrainer.classes_" title="Permalink to this definition">¶</a></dt>
|
<em class="property"><span class="pre">property</span> </em><span class="sig-name descname"><span class="pre">classes_</span></span><a class="headerlink" href="#quapy.method.neural.QuaNetTrainer.classes_" title="Permalink to this definition">¶</a></dt>
|
||||||
|
@ -1800,17 +1863,14 @@ registered hooks while the latter silently ignores them.</p>
|
||||||
<dl class="py method">
|
<dl class="py method">
|
||||||
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetTrainer.clean_checkpoint">
|
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetTrainer.clean_checkpoint">
|
||||||
<span class="sig-name descname"><span class="pre">clean_checkpoint</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#quapy.method.neural.QuaNetTrainer.clean_checkpoint" title="Permalink to this definition">¶</a></dt>
|
<span class="sig-name descname"><span class="pre">clean_checkpoint</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#quapy.method.neural.QuaNetTrainer.clean_checkpoint" title="Permalink to this definition">¶</a></dt>
|
||||||
<dd></dd></dl>
|
<dd><p>Removes the checkpoint</p>
|
||||||
|
</dd></dl>
|
||||||
|
|
||||||
<dl class="py method">
|
<dl class="py method">
|
||||||
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetTrainer.clean_checkpoint_dir">
|
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetTrainer.clean_checkpoint_dir">
|
||||||
<span class="sig-name descname"><span class="pre">clean_checkpoint_dir</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#quapy.method.neural.QuaNetTrainer.clean_checkpoint_dir" title="Permalink to this definition">¶</a></dt>
|
<span class="sig-name descname"><span class="pre">clean_checkpoint_dir</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#quapy.method.neural.QuaNetTrainer.clean_checkpoint_dir" title="Permalink to this definition">¶</a></dt>
|
||||||
<dd></dd></dl>
|
<dd><p>Removes anything contained in the checkpoint directory</p>
|
||||||
|
</dd></dl>
|
||||||
<dl class="py method">
|
|
||||||
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetTrainer.epoch">
|
|
||||||
<span class="sig-name descname"><span class="pre">epoch</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data</span></span><span class="p"><span class="pre">:</span></span> <span class="n"><a class="reference internal" href="quapy.data.html#quapy.data.base.LabelledCollection" title="quapy.data.base.LabelledCollection"><span class="pre">quapy.data.base.LabelledCollection</span></a></span></em>, <em class="sig-param"><span class="n"><span class="pre">posteriors</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">iterations</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">epoch</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">early_stop</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">train</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#quapy.method.neural.QuaNetTrainer.epoch" title="Permalink to this definition">¶</a></dt>
|
|
||||||
<dd></dd></dl>
|
|
||||||
|
|
||||||
<dl class="py method">
|
<dl class="py method">
|
||||||
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetTrainer.fit">
|
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetTrainer.fit">
|
||||||
|
@ -1819,10 +1879,10 @@ registered hooks while the latter silently ignores them.</p>
|
||||||
<dl class="field-list simple">
|
<dl class="field-list simple">
|
||||||
<dt class="field-odd">Parameters</dt>
|
<dt class="field-odd">Parameters</dt>
|
||||||
<dd class="field-odd"><ul class="simple">
|
<dd class="field-odd"><ul class="simple">
|
||||||
<li><p><strong>data</strong> – the training data on which to train QuaNet. If fit_learner=True, the data will be split in
|
<li><p><strong>data</strong> – the training data on which to train QuaNet. If <cite>fit_learner=True</cite>, the data will be split in
|
||||||
40/40/20 for training the classifier, training QuaNet, and validating QuaNet, respectively. If
|
40/40/20 for training the classifier, training QuaNet, and validating QuaNet, respectively. If
|
||||||
fit_learner=False, the data will be split in 66/34 for training QuaNet and validating it, respectively.</p></li>
|
<cite>fit_learner=False</cite>, the data will be split in 66/34 for training QuaNet and validating it, respectively.</p></li>
|
||||||
<li><p><strong>fit_learner</strong> – if true, trains the classifier on a split containing 40% of the data</p></li>
|
<li><p><strong>fit_learner</strong> – if True, trains the classifier on a split containing 40% of the data</p></li>
|
||||||
</ul>
|
</ul>
|
||||||
</dd>
|
</dd>
|
||||||
<dt class="field-even">Returns</dt>
|
<dt class="field-even">Returns</dt>
|
||||||
|
@ -1831,11 +1891,6 @@ fit_learner=False, the data will be split in 66/34 for training QuaNet and valid
|
||||||
</dl>
|
</dl>
|
||||||
</dd></dl>
|
</dd></dl>
|
||||||
|
|
||||||
<dl class="py method">
|
|
||||||
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetTrainer.get_aggregative_estims">
|
|
||||||
<span class="sig-name descname"><span class="pre">get_aggregative_estims</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">posteriors</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#quapy.method.neural.QuaNetTrainer.get_aggregative_estims" title="Permalink to this definition">¶</a></dt>
|
|
||||||
<dd></dd></dl>
|
|
||||||
|
|
||||||
<dl class="py method">
|
<dl class="py method">
|
||||||
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetTrainer.get_params">
|
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetTrainer.get_params">
|
||||||
<span class="sig-name descname"><span class="pre">get_params</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">deep</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#quapy.method.neural.QuaNetTrainer.get_params" title="Permalink to this definition">¶</a></dt>
|
<span class="sig-name descname"><span class="pre">get_params</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">deep</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#quapy.method.neural.QuaNetTrainer.get_params" title="Permalink to this definition">¶</a></dt>
|
||||||
|
@ -1852,7 +1907,7 @@ fit_learner=False, the data will be split in 66/34 for training QuaNet and valid
|
||||||
|
|
||||||
<dl class="py method">
|
<dl class="py method">
|
||||||
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetTrainer.quantify">
|
<dt class="sig sig-object py" id="quapy.method.neural.QuaNetTrainer.quantify">
|
||||||
<span class="sig-name descname"><span class="pre">quantify</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">instances</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#quapy.method.neural.QuaNetTrainer.quantify" title="Permalink to this definition">¶</a></dt>
|
<span class="sig-name descname"><span class="pre">quantify</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">instances</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#quapy.method.neural.QuaNetTrainer.quantify" title="Permalink to this definition">¶</a></dt>
|
||||||
<dd><p>Generate class prevalence estimates for the sample’s instances</p>
|
<dd><p>Generate class prevalence estimates for the sample’s instances</p>
|
||||||
<dl class="field-list simple">
|
<dl class="field-list simple">
|
||||||
<dt class="field-odd">Parameters</dt>
|
<dt class="field-odd">Parameters</dt>
|
||||||
|
@ -1880,7 +1935,19 @@ fit_learner=False, the data will be split in 66/34 for training QuaNet and valid
|
||||||
<dl class="py function">
|
<dl class="py function">
|
||||||
<dt class="sig sig-object py" id="quapy.method.neural.mae_loss">
|
<dt class="sig sig-object py" id="quapy.method.neural.mae_loss">
|
||||||
<span class="sig-prename descclassname"><span class="pre">quapy.method.neural.</span></span><span class="sig-name descname"><span class="pre">mae_loss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">output</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#quapy.method.neural.mae_loss" title="Permalink to this definition">¶</a></dt>
|
<span class="sig-prename descclassname"><span class="pre">quapy.method.neural.</span></span><span class="sig-name descname"><span class="pre">mae_loss</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">output</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">target</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#quapy.method.neural.mae_loss" title="Permalink to this definition">¶</a></dt>
|
||||||
<dd></dd></dl>
|
<dd><p>Torch-like wrapper for the Mean Absolute Error</p>
|
||||||
|
<dl class="field-list simple">
|
||||||
|
<dt class="field-odd">Parameters</dt>
|
||||||
|
<dd class="field-odd"><ul class="simple">
|
||||||
|
<li><p><strong>output</strong> – predictions</p></li>
|
||||||
|
<li><p><strong>target</strong> – ground truth values</p></li>
|
||||||
|
</ul>
|
||||||
|
</dd>
|
||||||
|
<dt class="field-even">Returns</dt>
|
||||||
|
<dd class="field-even"><p>mean absolute error loss</p>
|
||||||
|
</dd>
|
||||||
|
</dl>
|
||||||
|
</dd></dl>
|
||||||
|
|
||||||
</section>
|
</section>
|
||||||
<section id="module-quapy.method.non_aggregative">
|
<section id="module-quapy.method.non_aggregative">
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -11,6 +11,53 @@ from quapy.util import EarlyStop
|
||||||
|
|
||||||
|
|
||||||
class QuaNetTrainer(BaseQuantifier):
|
class QuaNetTrainer(BaseQuantifier):
|
||||||
|
"""
|
||||||
|
Implementation of `QuaNet <https://dl.acm.org/doi/abs/10.1145/3269206.3269287>`_, a neural network for
|
||||||
|
quantification. This implementation uses `PyTorch <https://pytorch.org/>`_ and can take advantage of GPU
|
||||||
|
to speed up the training phase.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
>>> import quapy as qp
|
||||||
|
>>> from quapy.method.meta import QuaNet
|
||||||
|
>>> from quapy.classification.neural import NeuralClassifierTrainer, CNNnet
|
||||||
|
>>>
|
||||||
|
>>> # use samples of 100 elements
|
||||||
|
>>> qp.environ['SAMPLE_SIZE'] = 100
|
||||||
|
>>>
|
||||||
|
>>> # load the kindle dataset as text, and convert words to numerical indexes
|
||||||
|
>>> dataset = qp.datasets.fetch_reviews('kindle', pickle=True)
|
||||||
|
>>> qp.data.preprocessing.index(dataset, min_df=5, inplace=True)
|
||||||
|
>>>
|
||||||
|
>>> # the text classifier is a CNN trained by NeuralClassifierTrainer
|
||||||
|
>>> cnn = CNNnet(dataset.vocabulary_size, dataset.n_classes)
|
||||||
|
>>> learner = NeuralClassifierTrainer(cnn, device='cuda')
|
||||||
|
>>>
|
||||||
|
>>> # train QuaNet (QuaNet is an alias to QuaNetTrainer)
|
||||||
|
>>> model = QuaNet(learner, qp.environ['SAMPLE_SIZE'], device='cuda')
|
||||||
|
>>> model.fit(dataset.training)
|
||||||
|
>>> estim_prevalence = model.quantify(dataset.test.instances)
|
||||||
|
|
||||||
|
:param learner: an object implementing `fit` (i.e., that can be trained on labelled data),
|
||||||
|
`predict_proba` (i.e., that can generate posterior probabilities of unlabelled examples) and
|
||||||
|
`transform` (i.e., that can generate embedded representations of the unlabelled instances).
|
||||||
|
:param sample_size: integer, the sample size
|
||||||
|
:param n_epochs: integer, maximum number of training epochs
|
||||||
|
:param tr_iter_per_poch: integer, number of training iterations before considering an epoch complete
|
||||||
|
:param va_iter_per_poch: integer, number of validation iterations to perform after each epoch
|
||||||
|
:param lr: float, the learning rate
|
||||||
|
:param lstm_hidden_size: integer, hidden dimensionality of the LSTM cells
|
||||||
|
:param lstm_nlayers: integer, number of LSTM layers
|
||||||
|
:param ff_layers: list of integers, dimensions of the densely-connected FF layers on top of the
|
||||||
|
quantification embedding
|
||||||
|
:param bidirectional: boolean, indicates whether the LSTM is bidirectional or not
|
||||||
|
:param qdrop_p: float, dropout probability
|
||||||
|
:param patience: integer, number of epochs showing no improvement in the validation set before stopping the
|
||||||
|
training phase (early stopping)
|
||||||
|
:param checkpointdir: string, path of the directory in which the model checkpoints are stored
|
||||||
|
:param checkpointname: string (optional), the name of the model's checkpoint
|
||||||
|
:param device: string, the device to use; either "cpu" or "cuda"
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
learner,
|
learner,
|
||||||
|
@ -28,6 +75,7 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
checkpointdir='../checkpoint',
|
checkpointdir='../checkpoint',
|
||||||
checkpointname=None,
|
checkpointname=None,
|
||||||
device='cuda'):
|
device='cuda'):
|
||||||
|
|
||||||
assert hasattr(learner, 'transform'), \
|
assert hasattr(learner, 'transform'), \
|
||||||
f'the learner {learner.__class__.__name__} does not seem to be able to produce document embeddings ' \
|
f'the learner {learner.__class__.__name__} does not seem to be able to produce document embeddings ' \
|
||||||
f'since it does not implement the method "transform"'
|
f'since it does not implement the method "transform"'
|
||||||
|
@ -64,10 +112,10 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
"""
|
"""
|
||||||
Trains QuaNet.
|
Trains QuaNet.
|
||||||
|
|
||||||
:param data: the training data on which to train QuaNet. If fit_learner=True, the data will be split in
|
:param data: the training data on which to train QuaNet. If `fit_learner=True`, the data will be split in
|
||||||
40/40/20 for training the classifier, training QuaNet, and validating QuaNet, respectively. If
|
40/40/20 for training the classifier, training QuaNet, and validating QuaNet, respectively. If
|
||||||
fit_learner=False, the data will be split in 66/34 for training QuaNet and validating it, respectively.
|
`fit_learner=False`, the data will be split in 66/34 for training QuaNet and validating it, respectively.
|
||||||
:param fit_learner: if true, trains the classifier on a split containing 40% of the data
|
:param fit_learner: if True, trains the classifier on a split containing 40% of the data
|
||||||
:return: self
|
:return: self
|
||||||
"""
|
"""
|
||||||
self._classes_ = data.classes_
|
self._classes_ = data.classes_
|
||||||
|
@ -125,8 +173,8 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
checkpoint = self.checkpoint
|
checkpoint = self.checkpoint
|
||||||
|
|
||||||
for epoch_i in range(1, self.n_epochs):
|
for epoch_i in range(1, self.n_epochs):
|
||||||
self.epoch(train_data_embed, train_posteriors, self.tr_iter, epoch_i, early_stop, train=True)
|
self._epoch(train_data_embed, train_posteriors, self.tr_iter, epoch_i, early_stop, train=True)
|
||||||
self.epoch(valid_data_embed, valid_posteriors, self.va_iter, epoch_i, early_stop, train=False)
|
self._epoch(valid_data_embed, valid_posteriors, self.va_iter, epoch_i, early_stop, train=False)
|
||||||
|
|
||||||
early_stop(self.status['va-loss'], epoch_i)
|
early_stop(self.status['va-loss'], epoch_i)
|
||||||
if early_stop.IMPROVED:
|
if early_stop.IMPROVED:
|
||||||
|
@ -139,7 +187,7 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def get_aggregative_estims(self, posteriors):
|
def _get_aggregative_estims(self, posteriors):
|
||||||
label_predictions = np.argmax(posteriors, axis=-1)
|
label_predictions = np.argmax(posteriors, axis=-1)
|
||||||
prevs_estim = []
|
prevs_estim = []
|
||||||
for quantifier in self.quantifiers.values():
|
for quantifier in self.quantifiers.values():
|
||||||
|
@ -150,10 +198,10 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
|
|
||||||
return prevs_estim
|
return prevs_estim
|
||||||
|
|
||||||
def quantify(self, instances, *args):
|
def quantify(self, instances):
|
||||||
posteriors = self.learner.predict_proba(instances)
|
posteriors = self.learner.predict_proba(instances)
|
||||||
embeddings = self.learner.transform(instances)
|
embeddings = self.learner.transform(instances)
|
||||||
quant_estims = self.get_aggregative_estims(posteriors)
|
quant_estims = self._get_aggregative_estims(posteriors)
|
||||||
self.quanet.eval()
|
self.quanet.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
prevalence = self.quanet.forward(embeddings, posteriors, quant_estims)
|
prevalence = self.quanet.forward(embeddings, posteriors, quant_estims)
|
||||||
|
@ -162,7 +210,7 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
prevalence = prevalence.numpy().flatten()
|
prevalence = prevalence.numpy().flatten()
|
||||||
return prevalence
|
return prevalence
|
||||||
|
|
||||||
def epoch(self, data: LabelledCollection, posteriors, iterations, epoch, early_stop, train):
|
def _epoch(self, data: LabelledCollection, posteriors, iterations, epoch, early_stop, train):
|
||||||
mse_loss = MSELoss()
|
mse_loss = MSELoss()
|
||||||
|
|
||||||
self.quanet.train(mode=train)
|
self.quanet.train(mode=train)
|
||||||
|
@ -181,7 +229,7 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
for it, index in enumerate(pbar):
|
for it, index in enumerate(pbar):
|
||||||
sample_data = data.sampling_from_index(index)
|
sample_data = data.sampling_from_index(index)
|
||||||
sample_posteriors = posteriors[index]
|
sample_posteriors = posteriors[index]
|
||||||
quant_estims = self.get_aggregative_estims(sample_posteriors)
|
quant_estims = self._get_aggregative_estims(sample_posteriors)
|
||||||
ptrue = torch.as_tensor([sample_data.prevalence()], dtype=torch.float, device=self.device)
|
ptrue = torch.as_tensor([sample_data.prevalence()], dtype=torch.float, device=self.device)
|
||||||
if train:
|
if train:
|
||||||
self.optim.zero_grad()
|
self.optim.zero_grad()
|
||||||
|
@ -236,9 +284,15 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
f'the parameters of QuaNet or the learner {self.learner.__class__.__name__}')
|
f'the parameters of QuaNet or the learner {self.learner.__class__.__name__}')
|
||||||
|
|
||||||
def clean_checkpoint(self):
|
def clean_checkpoint(self):
|
||||||
|
"""
|
||||||
|
Removes the checkpoint file
|
||||||
|
"""
|
||||||
os.remove(self.checkpoint)
|
os.remove(self.checkpoint)
|
||||||
|
|
||||||
def clean_checkpoint_dir(self):
|
def clean_checkpoint_dir(self):
|
||||||
|
"""
|
||||||
|
Removes the checkpoint directory together with everything contained in it
|
||||||
|
"""
|
||||||
import shutil
|
import shutil
|
||||||
shutil.rmtree(self.checkpointdir, ignore_errors=True)
|
shutil.rmtree(self.checkpointdir, ignore_errors=True)
|
||||||
|
|
||||||
|
@ -248,10 +302,33 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
|
|
||||||
|
|
||||||
def mae_loss(output, target):
|
def mae_loss(output, target):
|
||||||
|
"""
|
||||||
|
Torch-like wrapper for the Mean Absolute Error
|
||||||
|
|
||||||
|
:param output: predictions
|
||||||
|
:param target: ground truth values
|
||||||
|
:return: mean absolute error loss
|
||||||
|
"""
|
||||||
return torch.mean(torch.abs(output - target))
|
return torch.mean(torch.abs(output - target))
|
||||||
|
|
||||||
|
|
||||||
class QuaNetModule(torch.nn.Module):
|
class QuaNetModule(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Implements the `QuaNet <https://dl.acm.org/doi/abs/10.1145/3269206.3269287>`_ forward pass.
|
||||||
|
See :class:`QuaNetTrainer` for training QuaNet.
|
||||||
|
|
||||||
|
:param doc_embedding_size: integer, the dimensionality of the document embeddings
|
||||||
|
:param n_classes: integer, number of classes
|
||||||
|
:param stats_size: integer, number of statistics estimated by simple quantification methods
|
||||||
|
:param lstm_hidden_size: integer, hidden dimensionality of the LSTM cell
|
||||||
|
:param lstm_nlayers: integer, number of LSTM layers
|
||||||
|
:param ff_layers: list of integers, dimensions of the densely-connected FF layers on top of the
|
||||||
|
quantification embedding
|
||||||
|
:param bidirectional: boolean, whether or not to use bidirectional LSTM
|
||||||
|
:param qdrop_p: float, dropout probability
|
||||||
|
:param order_by: integer, class for which the document embeddings are to be sorted
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
doc_embedding_size,
|
doc_embedding_size,
|
||||||
n_classes,
|
n_classes,
|
||||||
|
@ -262,6 +339,7 @@ class QuaNetModule(torch.nn.Module):
|
||||||
bidirectional=True,
|
bidirectional=True,
|
||||||
qdrop_p=0.5,
|
qdrop_p=0.5,
|
||||||
order_by=0):
|
order_by=0):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.n_classes = n_classes
|
self.n_classes = n_classes
|
||||||
|
@ -289,7 +367,7 @@ class QuaNetModule(torch.nn.Module):
|
||||||
def device(self):
|
def device(self):
|
||||||
return torch.device('cuda') if next(self.parameters()).is_cuda else torch.device('cpu')
|
return torch.device('cuda') if next(self.parameters()).is_cuda else torch.device('cpu')
|
||||||
|
|
||||||
def init_hidden(self):
|
def _init_hidden(self):
|
||||||
directions = 2 if self.bidirectional else 1
|
directions = 2 if self.bidirectional else 1
|
||||||
var_hidden = torch.zeros(self.nlayers * directions, 1, self.hidden_size)
|
var_hidden = torch.zeros(self.nlayers * directions, 1, self.hidden_size)
|
||||||
var_cell = torch.zeros(self.nlayers * directions, 1, self.hidden_size)
|
var_cell = torch.zeros(self.nlayers * directions, 1, self.hidden_size)
|
||||||
|
@ -315,7 +393,7 @@ class QuaNetModule(torch.nn.Module):
|
||||||
embeded_posteriors = embeded_posteriors.unsqueeze(0)
|
embeded_posteriors = embeded_posteriors.unsqueeze(0)
|
||||||
|
|
||||||
self.lstm.flatten_parameters()
|
self.lstm.flatten_parameters()
|
||||||
_, (rnn_hidden,_) = self.lstm(embeded_posteriors, self.init_hidden())
|
_, (rnn_hidden,_) = self.lstm(embeded_posteriors, self._init_hidden())
|
||||||
rnn_hidden = rnn_hidden.view(self.nlayers, self.ndirections, 1, self.hidden_size)
|
rnn_hidden = rnn_hidden.view(self.nlayers, self.ndirections, 1, self.hidden_size)
|
||||||
quant_embedding = rnn_hidden[0].view(-1)
|
quant_embedding = rnn_hidden[0].view(-1)
|
||||||
quant_embedding = torch.cat((quant_embedding, statistics))
|
quant_embedding = torch.cat((quant_embedding, statistics))
|
||||||
|
|
Loading…
Reference in New Issue