Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conflicts resolution #1

Closed
wants to merge 275 commits into from
Closed
Changes from 1 commit
Commits
Show all changes
275 commits
Select commit Hold shift + click to select a range
6b2b7d1
factor mask
Apr 29, 2021
5225331
cachedShortLemmaEt_ works
Apr 29, 2021
d41353e
get ready for cachedShortWt_
Apr 29, 2021
e518fc9
cachedShortWt_ works
Apr 29, 2021
1784da0
start lsh
Apr 29, 2021
947301a
lsh runs but crap output
Apr 29, 2021
1a3e5ab
debug
Apr 29, 2021
daf853e
batch idx nearly there
Apr 29, 2021
1672201
use origBatchIdx
Apr 29, 2021
5be8249
virtual destructor
Apr 30, 2021
84d4987
warnings
Apr 30, 2021
1e62a16
warnings
Apr 30, 2021
b8153bb
warnings
Apr 30, 2021
560bdbd
warnings
Apr 30, 2021
82db7ab
start using only expr
Apr 30, 2021
86d7e30
getFactorMasksMultiDim
Apr 30, 2021
0777343
debug
Apr 30, 2021
ab2afff
old-style iter. For gcc 5
Apr 30, 2021
4500221
don't define BLAS_FOUND
Apr 30, 2021
7faebf7
use args
Apr 30, 2021
379212b
Enable compute86 where supported (#863)
XapaJIaMnu May 4, 2021
fe74576
Update VERSION
snukky May 4, 2021
8b818b7
Avoid Ampere misaligment issue
emjotde May 17, 2021
84a20f6
Merge branch 'master' into pmaster
emjotde May 24, 2021
3133a9b
resolve conflict
emjotde May 24, 2021
9fa166b
Online quantization (#847)
ykim362 May 25, 2021
6e87f16
Merged PR 18763: Fix adding new validation metrics with --valid-reset…
May 26, 2021
2c1b16f
Merged PR 19252: Update sentencepiece module to include CMake changes
rjai Jun 4, 2021
28e5e22
filter once for shortlist
Jun 4, 2021
f19ebba
debug
Jun 4, 2021
77c0cac
broadcasting bdot
emjotde Jun 7, 2021
2e6f029
add legacy bdot
emjotde Jun 7, 2021
1d96d7b
add legacy code on cpu
emjotde Jun 7, 2021
ce34df4
add legacy code on gpu
emjotde Jun 7, 2021
1c0b899
Merge branch 'pmaster'
emjotde Jun 7, 2021
bc4ad24
Merge branch 'mjd/bdot' into hihoan/lsh7
Jun 7, 2021
0949a4c
start using bdot
Jun 7, 2021
acdff77
reduce tranform for no-shortlist
Jun 7, 2021
b5f97dc
reshape cachedShortLemmaEt
Jun 7, 2021
eb3f540
debug
Jun 7, 2021
92c6c07
reshape cachedShortWt_
Jun 7, 2021
e07e036
debug
Jun 9, 2021
5d1946e
filter & broadcast every word. SL works
Jun 9, 2021
0bc9b22
separate broadcast
Jun 9, 2021
79dbde7
don't manually broadcast weights
Jun 9, 2021
4b9082b
don't manually broadcast lemma
Jun 9, 2021
1e3db86
batch based filtering. COmment out debug
Jun 9, 2021
6f0f534
debug
Jun 9, 2021
fe97259
debug
Jun 9, 2021
5a93c67
origBatchIdx -> currentBatchIdx. Doesn't crash but bad results
Jun 9, 2021
fef7202
batch-beam -> beam-batch
Jun 11, 2021
f025188
debug
Jun 11, 2021
4999821
don't transpose lastIndices. Works for lsh
Jun 11, 2021
700dc7f
don't transpose lastIndices. Works for lsh & sl
Jun 11, 2021
8649034
no need to broadcast
Jun 11, 2021
cc29593
incorrect dimension order
Jun 14, 2021
dffbb47
rename broadcast -> createCachedTensors
Jun 14, 2021
8c04f66
reverse batch beam argument order
Jun 15, 2021
5b7b1f7
no need for args in getIndicesExpr(). Deleted debugging
Jun 15, 2021
5362c2c
don't define BLAS_FOUND
Jun 15, 2021
82fa059
'use' variables
Jun 15, 2021
7e6ec58
delete variables altogether
Jun 15, 2021
488a532
get lemma size from vocab class
Jun 15, 2021
6981b21
Merge branch 'hihoan/lsh7' of vs-ssh.visualstudio.com:v3/machinetrans…
Jun 15, 2021
395a4f9
init vector
Jun 16, 2021
8925541
lemma Et is optional
Jun 16, 2021
9b4a845
clean up bias
Jun 16, 2021
85eb6ad
update sentencepiece pointer to version with case-awareness
emjotde Jun 16, 2021
a332e55
debug
Jun 16, 2021
cd292d3
changes for review
Jun 18, 2021
fc0f41f
Merged PR 19597: Enable mpi wrapper to use size larger than MAX_INT
emjotde Jun 28, 2021
8daa0a4
fix compilation errors due to narrow conversion
emjotde Jun 29, 2021
24c644b
pass shortlist regression tests
Jun 29, 2021
7e2fce8
Merge branch 'hihoan/lsh7' of vs-ssh.visualstudio.com:v3/machinetrans…
Jun 29, 2021
64e787a
Merge branch 'master' into hihoan/lsh7
emjotde Jun 29, 2021
ff8af52
lock index before creation
Jul 1, 2021
bd1f1ee
marcin's review changes
Jul 2, 2021
9acf27d
credit SLIDE
Jul 2, 2021
5ad0edf
remove todo
Jul 3, 2021
4ace42f
paper
Jul 3, 2021
8bfa6a4
Merge branch 'hihoan/lsh7' of vs-ssh.visualstudio.com:v3/machinetrans…
emjotde Jul 3, 2021
9772aa2
remaining comments
emjotde Jul 3, 2021
d6c09b2
Merged PR 19409: Unify LSH and short list interface
Jul 3, 2021
35c822e
Merged PR 19685: Marianize LSH as operators for mmapping and use in Q…
emjotde Jul 9, 2021
3a478fc
update version and changelog
emjotde Jul 9, 2021
7e6ea51
silence unreferenced formal parameter warning on windows
emjotde Jul 10, 2021
42f0b8b
Binary shortlist (#856)
qianqianzhu Jul 11, 2021
8e88071
Merged PR 19842: Adapt LSH to work with Leaf
emjotde Jul 16, 2021
056c4be
Merged PR 19860: Case augmented data, if not using factored vocab mus…
rjai Jul 17, 2021
f6cb1b5
Merged PR 19864: add bias if it exists
Jul 21, 2021
b83b06f
Merged PR 19914: Fix Windows Azure Pipelines
Jul 22, 2021
6b568f4
Merged PR 19904: Update instructions for building on Windows
Jul 22, 2021
b653db0
Merged PR 19910: Fix training/scoring error with FSM
emjotde Jul 22, 2021
4ff2ef1
Merged PR 19761: Expose SPM Interface from Marian
rjai Jul 30, 2021
d124ca9
allow float32 conversion in QS interface
emjotde Aug 4, 2021
e025bfb
Merged PR 20070: Run regression tests in Azure Pipelines
Aug 6, 2021
6652b31
Merged PR 20560: Update SPM in Marian
rjai Sep 2, 2021
8d0a3c0
Add --allow-unauthenticated when installing CUDA (#878)
snukky Sep 7, 2021
4dd30b5
Factor concatenation improvements and documentation (#748)
kpu Sep 8, 2021
8470c16
Merged PR 20230: Add option for running regression tests only in Azur…
Sep 16, 2021
aa58ba8
Merged PR 20593: Fix and update Azure pipelines
Sep 20, 2021
d796a3c
Merged PR 20839: Do not ignore ignoreEOS for spm decoding
emjotde Sep 28, 2021
03fe175
Merged PR 20879: Adjustable ffn width and depth in transformer decoder
emjotde Sep 28, 2021
12a1bfa
Remove Ubuntu 16.04 from GitHub workflows (#879)
snukky Oct 11, 2021
2d79ad0
Merged PR 20933: beam & batch works for n on-factored models
Oct 13, 2021
7f06f3c
Merged PR 21166: Keep building on macOS-10.15
Oct 26, 2021
1404201
Merged PR 21151: Cleaning up fp16 behavior
emjotde Oct 26, 2021
2bdfbd3
Update badges in README.md
snukky Nov 21, 2021
c85d060
Merged PR 20729: Add top-k sampling
emjotde Nov 22, 2021
3b4e943
Added pragma to ignore unused-private-field error on elementType_ whi…
dameikle Nov 22, 2021
3d15cd3
Update submodule regression-tests
snukky Nov 22, 2021
1adf80b
Task alias validation during training mode (#886)
XapaJIaMnu Nov 22, 2021
ab6b826
Add GCC 11 support (#888)
XapaJIaMnu Nov 23, 2021
8b8d1b1
Merged PR 21553: Parallelize data reading for training
emjotde Nov 25, 2021
bbc673c
update CHANGELOG and VERSION
emjotde Nov 25, 2021
c64cb29
Constrain version of mistune to before v2 in GitHub CI Documentation …
graemenail Dec 6, 2021
e8ea37c
Merged PR 21648: Allow for dynamic gradient scaling to fade out after…
emjotde Dec 6, 2021
cd9afea
Documentation about how to write code documentation (#891)
qianqianzhu Dec 7, 2021
e8a1a25
Fix AVX2+ detection on Mac (#895)
XapaJIaMnu Dec 7, 2021
e26e5b6
Use apple accelerate on MacOs by default (#897)
XapaJIaMnu Dec 16, 2021
c84599d
Update VERSION
snukky Dec 16, 2021
b29cc07
Scorer model loading (#860)
graemenail Jan 18, 2022
b64e258
Update VERSION
snukky Jan 18, 2022
894a07a
Improve checks on transformer cache (#881)
graemenail Jan 24, 2022
3b458b0
Update VERSION
snukky Jan 24, 2022
71b5454
Layer documentation (#892)
qianqianzhu Jan 26, 2022
07c39c7
Cherry picked cleaning/refeactoring patches (#905)
snukky Jan 28, 2022
266b931
Update list of contributors (#906)
snukky Jan 30, 2022
8da539e
merged with master
emjotde Feb 6, 2022
3cf9e83
resolve conflicts
emjotde Feb 6, 2022
aafe8fb
update regression tests pointer
emjotde Feb 7, 2022
a365bb5
fix server behaviour
emjotde Feb 7, 2022
05ba9e4
add -DDETERMINISTIC=ON/OFF flag (#912)
emjotde Feb 8, 2022
8e659bb
Document Structure (#910)
graemenail Feb 8, 2022
f00d062
update VERSION and CHANGELOG - Release 1.11.0
emjotde Feb 8, 2022
bcf29b8
Update acknowledgements (#914)
graemenail Feb 9, 2022
b976458
Update release workflow (#915)
snukky Feb 9, 2022
73f1899
Add dependabot for git submodules (#916)
snukky Feb 10, 2022
a492bc5
Bump regression-tests from `0716f4e` to `f7971b7` (#918)
dependabot[bot] Feb 10, 2022
4d44627
PyYaml safe_load instead of load (#913)
graemenail Feb 10, 2022
17e55f5
Update VERSION
snukky Feb 10, 2022
8fd553e
Bump examples from `6d5921c` to `0ca966e` (#919)
dependabot[bot] Feb 10, 2022
e6dbacb
Merged PR 22490: Faster LSH top-k for CPU
emjotde Feb 10, 2022
b3feecc
Merged PR 22483: Make C++17 the official standard for Marian
emjotde Feb 10, 2022
3b21ff3
update VERSION and CHANGELOG
emjotde Feb 10, 2022
4b51dcb
Merged PR 22524: Optimize guided alignment training speed via sparse …
emjotde Feb 11, 2022
b0275e7
merge with internal master
emjotde Feb 11, 2022
b8bf086
move regression-tests pointer
emjotde Feb 11, 2022
8a9580b
update the intgemm version to upstream (#924)
XapaJIaMnu Feb 15, 2022
58c4576
Bump regression-tests from `da95717` to `88e6382` (#923)
dependabot[bot] Feb 15, 2022
601c9ac
Detect fortran_order in npz (#911)
graemenail Feb 15, 2022
adaaf08
better error message
emjotde Feb 16, 2022
310d2f4
Merged PR 22939: Fix case augmentation with multi-threaded reading
emjotde Mar 7, 2022
16bfa0c
Merged PR 23094: Adapt --cost-scaling to more stable setting
emjotde Mar 16, 2022
c809843
Bump examples from `6d5921c` to `29f4f7c` (#928)
dependabot[bot] Mar 22, 2022
75a7a1d
Bump regression-tests from `88e6382` to `4fa9ff5` (#929)
dependabot[bot] Mar 22, 2022
78bef7a
Bump src/3rd_party/sentencepiece from `c307b87` to `5312a30` (#927)
dependabot[bot] Mar 22, 2022
23c36ec
Fixed fp16 training/inference with factors-combine concat (#926)
arturnn Mar 22, 2022
d5c7372
Merged PR 23407: Fix incorrect/missing gradient accumulation for affi…
emjotde Apr 8, 2022
1e4e101
Merged PR 23415: Set Windows image back to windows-2019
Apr 8, 2022
1a74358
Merged PR 23429: Small fixes around fp16 training and batch fitting
emjotde Apr 11, 2022
e4f3d0f
add fallback option for sampling, for back-compat
emjotde May 9, 2022
e0e3287
Merged PR 23840: Update CUDA installation script for Ubuntu
May 12, 2022
704a323
Merged PR 22799: Running regression tests on Azure Pipelines
May 13, 2022
95720ae
Update NVIDIA CUDA signing key for CI; fix for building docs (#932)
graemenail May 18, 2022
f3e1efe
merge with internal master
emjotde May 26, 2022
042ed8f
Merged PR 24072: Revert changes to transformer caching
emjotde May 30, 2022
5df240f
Update status badges (#935)
snukky May 31, 2022
c5081df
Merged PR 24111: Remove external reference to Docker images
May 31, 2022
e27da62
Directory listing in Ubuntu and macOS workflows (#938)
graemenail Jun 6, 2022
a90950e
Merged PR 25154: Add model shapes flag to model_info.py script
alexandremuzio Aug 10, 2022
5d466bc
Merged PR 25507: Upgrade Azure Pipelines to ubuntu-20.04
Sep 2, 2022
f9a1ed1
Add a workflow compiling Marian using clang-14 (#940)
snukky Sep 2, 2022
6250cd8
Fixed some warnings on clang 15 that are promoted into errors (#936)
KOLANICH Sep 2, 2022
3bd281c
Fix clang 13.0.1 (#939)
XapaJIaMnu Sep 2, 2022
650cf19
Update Catch2 from 2.10.1 to 2.13.9 (#941)
graemenail Sep 2, 2022
bf5eafa
Bump src/3rd_party/intgemm from `a05a2e5` to `0eda93a` (#933)
dependabot[bot] Sep 2, 2022
7d65460
Fix guaranteed `YAML::InvalidNode` when compiled with `COMPILE_CPU=Of…
jelmervdl Sep 2, 2022
0afe247
Upgrade workflows to ubuntu-20.04 and macos-12 (#962)
snukky Sep 2, 2022
347ab4d
Upgrade dependencies in the documentation framework (#965)
snukky Sep 5, 2022
b6d0667
Bump regression-tests from `4fa9ff5` to `92e116e` (#964)
dependabot[bot] Sep 5, 2022
a5223e2
Bump examples from `29f4f7c` to `25e8438` (#963)
dependabot[bot] Sep 5, 2022
6b41df2
Version 1.11.8
snukky Sep 5, 2022
a47912d
Merged PR 25518: Upgrade Azure Pipelines to macos-12
Sep 15, 2022
6f7766f
Merged PR 25465: Choose top checkpoints from train.log for averaging
Sep 15, 2022
e13053a
Merged PR 25698: Install Python 3.8 on GPU pool
Sep 16, 2022
7696479
Merged PR 23767: More principled sampling and force-decoding
emjotde Sep 16, 2022
7d2045a
Merged PR 25686: Loading checkpoints from main node only via MPI
emjotde Sep 21, 2022
cfc33f5
only use tcmalloc_minimal
emjotde Sep 22, 2022
1f2929d
Merged PR 25733: Fused inplace ReLU and Dropout in transformer FFN layer
emjotde Sep 26, 2022
2cd3055
Merged PR 25836: Check via hashing if re-syncing in local mode is req…
emjotde Sep 27, 2022
2c55cdb
Merged PR 25889: Fixes bad memory access problem in hashing
emjotde Sep 29, 2022
1e92cff
Merged PR 25919: Sync with public master - no review required
emjotde Oct 4, 2022
da6e30b
merge with internal master
emjotde Oct 4, 2022
4d3702c
Merged PR 25950: Add missing defaults for concatenated factors
emjotde Oct 6, 2022
a6de1b7
Merged PR 26271: Update CI pipeline triggers
Nov 1, 2022
be1ee3f
Merged PR 26318: Fix incorrect envvar name in Azure Pipeline
Nov 1, 2022
cda2f21
Temporarily download MKL tarball from a mirror server (#972)
snukky Nov 2, 2022
07a2ac8
best-deep alias broken (#968)
XapaJIaMnu Nov 2, 2022
4187aab
Bump regression-tests from `92e116e` to `494d6de` (#973)
dependabot[bot] Nov 19, 2022
3634964
Bump src/3rd_party/sentencepiece from `31ac8e8` to `8dc9172` (#970)
dependabot[bot] Nov 19, 2022
c79dc80
Merged PR 26617: Update regression-tests & fix CI pipelines
Nov 20, 2022
b6581c4
Merged PR 26667: Update examples submodule to fix vulnerability issues
Nov 23, 2022
d5569ce
Bump regression-tests from `494d6de` to `488d454` (#974)
dependabot[bot] Nov 29, 2022
3c2a432
Bump examples from `25e8438` to `58f48a0` (#975)
dependabot[bot] Nov 29, 2022
b7205fc
Merged PR 25220: Add extra model information to model_info.py script
alexandremuzio Nov 30, 2022
ee50d4a
Merged PR 27051: Add an option for completely resetting validation me…
Dec 20, 2022
4f145c4
Merged PR 26311: [FSM] make model loading lock non-static
Feb 10, 2023
9ad5203
Merged PR 26476: Sanitize guided-alignment with case-augmentation (st…
emjotde Feb 11, 2023
031dbb3
Merged PR 27804: Fallback to old LSH code for MSVC due to bad loop un…
emjotde Feb 13, 2023
4ffd292
Merge branch 'master' into pmaster
emjotde Feb 20, 2023
9871c90
Merged PR 27999: Update internal master to public master
emjotde Feb 20, 2023
65bf82f
version 1.12.0 (#980)
emjotde Feb 21, 2023
ab1b8da
Merge internal master with public master
emjotde Feb 22, 2023
efcd3da
Merged PR 28059: Add missing default for factors
emjotde Feb 23, 2023
a23cc77
Merged PR 27976: Introduce new layer framework into master
emjotde Feb 27, 2023
d225c24
Merged PR 28128: Comet scoring and training with new layer framework
emjotde Mar 1, 2023
30f41da
Merged PR 28460: Revert "Merged PR 26311: [FSM] make model loading lo…
fsigalov Mar 16, 2023
26b178c
Merged PR 28179: comet2marian.py: download comet models automatically.
Mar 17, 2023
cd4d1ec
Merged PR 28674: Add --early-stopping-epsilon param
Mar 30, 2023
a421476
Merged PR 28502: Comet2Marian: add --spm argument to download vocabul…
Apr 13, 2023
cd78417
Bump examples from `58f48a0` to `6c40475` (#987)
dependabot[bot] Apr 14, 2023
1334fa5
Bump regression-tests from `2a8bed3` to `89ce02e` (#984)
dependabot[bot] Apr 14, 2023
8bf101c
Fix include path typo in onnx exporter (#978)
angrypie Apr 14, 2023
3daf4ee
quote CPUINFO in cmake (#983)
josharian Apr 15, 2023
d054dc8
Bump src/3rd_party/fbgemm from `6f45243` to `0e33146` (#995)
dependabot[bot] Jun 19, 2023
02678ef
Merged PR 29868: Add option to replace current parameters with smooth…
emjotde Jun 19, 2023
7425c02
Merged PR 30009: Divergence detection and fallback to fp32 if trainin…
emjotde Jun 27, 2023
ea8a2db
Merged PR 30038: Add a comment that automatic builds are disabled
Jun 28, 2023
0fa11f5
Merged PR 30034: Automatically create marian-YYYY-MM-DD-GIT_REV.tgz
emjotde Jun 28, 2023
0df870c
Merged PR 28958: LSH for GPU
Jun 29, 2023
cc66cf6
Merged PR 29966: More metrics in Marian and MBR scripts
emjotde Jun 29, 2023
d1d10a4
Merged PR 30079: Fixes and extends unit test for layer norm
emjotde Jul 1, 2023
bd63cce
Merged PR 28078: Various small improvements
emjotde Jul 3, 2023
a5b50f2
Merged PR 30282: Fix parameter name for norms in new layer framework.
emjotde Jul 16, 2023
c8f1e03
Merged PR 30198: [quicksand] cache YAML configs
Jul 17, 2023
c83d47f
Merged PR 30283: Save full checkpoints at saving intervals (with iter…
emjotde Jul 22, 2023
09cb320
Bump src/3rd_party/sentencepiece from `8dc9172` to `fb6f8e4` (#1000)
dependabot[bot] Jul 24, 2023
9af4740
Merged PR 30415: Fix macOS clang builds
Jul 24, 2023
b67489e
Merged PR 30419: Fix Python modules in GPU regression tests
Jul 25, 2023
68cc88f
Fix macOS actions (#1002)
snukky Jul 25, 2023
717d351
Merged PR 30406: More general fallbacks for diverged training
emjotde Jul 26, 2023
e383583
Merged PR 30482: Fixes for backward compatibility in fine-tuning
Jul 27, 2023
3bd25dd
Merged PR 30516: Make sure that loss is finite when checking for dive…
emjotde Jul 31, 2023
60aa66b
Merged PR 30704: Merge with public master from 20230814
emjotde Aug 14, 2023
8cbd2df
Merge branch 'master' into pmaster
emjotde Aug 14, 2023
3f93e65
don't include nppdefs.h. Problematic on some machines (#1004)
hieuhoang Aug 15, 2023
961a728
Add an option to not encode sentencepiece during training/decoding al…
XapaJIaMnu Aug 17, 2023
3b0594c
Resolved conflicts
samirsalman Aug 24, 2023
56abb91
fix syntax error
samirsalman Aug 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
batch-beam -> beam-batch
Hieu Hoang committed Jun 11, 2021
commit fef7202bc809ac1dd31a89a7f6697e110b681d95
49 changes: 28 additions & 21 deletions src/data/shortlist.cpp
Original file line number Diff line number Diff line change
@@ -129,13 +129,18 @@ LSHShortlist::LSHShortlist(int k, int nbits)
//#define BLAS_FOUND 1

WordIndex LSHShortlist::reverseMap(int batchIdx, int beamIdx, int idx) const {
std::cerr << "\nbatchIdx=" << batchIdx << " beamIdx=" << beamIdx << " idx=" << idx << std::endl;
std::cerr << "\nbatchIdx=" << batchIdx
<< " beamIdx=" << beamIdx
<< " idx=" << idx
<< " k_=" << k_
<< std::endl;
std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
int currBatchSize = indicesExpr_->shape()[0];
int currBeamSize = indicesExpr_->shape()[1];
int currBeamSize = indicesExpr_->shape()[0];
int currBatchSize = indicesExpr_->shape()[1];
std::cerr << "currBatchSize=" << currBatchSize << " currBeamSize=" << currBeamSize << std::endl;
std::cerr << "indices_=" << indices_.size() << std::endl;
idx = (k_ * currBeamSize) * batchIdx + k_ * beamIdx + idx;
idx = (k_ * currBatchSize * beamIdx) + (k_ * batchIdx) + idx;
//idx = (k_ * currBeamSize * batchIdx) + (k_ * beamIdx) + idx;
std::cerr << "idx=" << idx << std::endl;
assert(idx < indices_.size());
return indices_[idx];
@@ -152,13 +157,16 @@ WordIndex LSHShortlist::tryForwardMap(int , int , WordIndex wIdx) const {
}

Expr LSHShortlist::getIndicesExpr(int batchSize, int currBeamSize) const {
//std::cerr << "batchSize=" << batchSize << " currBeamSize=" << currBeamSize << std::endl;
//std::cerr << "indicesExpr_=" << indicesExpr_->shape() << " " << indicesExpr_->val() << std::endl;
assert(indicesExpr_->shape()[0] == batchSize);
assert(indicesExpr_->shape()[1] == currBeamSize);
return indicesExpr_;
std::cerr << "batchSize=" << batchSize << " currBeamSize=" << currBeamSize << std::endl;
std::cerr << "indicesExpr_=" << indicesExpr_->shape() << " " << indicesExpr_->val() << std::endl;
assert(indicesExpr_->shape()[0] == currBeamSize);
assert(indicesExpr_->shape()[1] == batchSize);
Expr ret = transpose(indicesExpr_, {1, 0, 2});
return ret;
}

#define BLAS_FOUND 1

void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
#if BLAS_FOUND
static int c = 0;
@@ -186,6 +194,7 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
index_->add( vRows, values->val()->data<float>());
}

std::cerr << "query=" << query->shape() << std::endl;
int qRows = query->shape().elements() / dim;
std::vector<float> distances(qRows * k_);
std::vector<faiss::Index::idx_t> ids(qRows * k_);
@@ -207,8 +216,8 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
out->val()->set(indices_);
};

Shape kShape({batchSize, currBeamSize, k_});
//std::cerr << "kShape=" << kShape << std::endl;
Shape kShape({currBeamSize, batchSize, k_});
std::cerr << "kShape=" << kShape << std::endl;

indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
//std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
@@ -227,9 +236,9 @@ void LSHShortlist::broadcast(Expr weights,
Expr lemmaEt,
Expr indicesExprBC,
int k) {
//std::cerr << "indicesExprBC.0=" << indicesExprBC->shape() << std::endl;
int batchSize = indicesExprBC->shape()[0];
int currBeamSize = indicesExprBC->shape()[1];
std::cerr << "indicesExprBC.0=" << indicesExprBC->shape() << std::endl;
int currBeamSize = indicesExprBC->shape()[0];
int batchSize = indicesExprBC->shape()[1];
//int numHypos = batchSize * currBeamSize;
//std::cerr << "batchSize=" << batchSize << std::endl;
//std::cerr << "currBeamSize=" << currBeamSize << std::endl;
@@ -239,14 +248,12 @@ void LSHShortlist::broadcast(Expr weights,
indicesExprBC = reshape(indicesExprBC, {indicesExprBC->shape().elements()});
//std::cerr << "indicesExprBC.2=" << indicesExprBC->shape() << std::endl;

//std::cerr << "currBeamSize=" << currBeamSize << " batchSize=" << batchSize << std::endl;
//std::cerr << "weights=" << weights->shape() << std::endl;
std::cerr << "currBeamSize=" << currBeamSize << " batchSize=" << batchSize << std::endl;
std::cerr << "weights=" << weights->shape() << std::endl;
cachedShortWt_ = index_select(weights, isLegacyUntransposedW ? -1 : 0, indicesExprBC);
//std::cerr << "cachedShortWt_.1=" << cachedShortWt_->shape() << std::endl;
cachedShortWt_ = reshape(cachedShortWt_, {batchSize, currBeamSize, k, cachedShortWt_->shape()[1]});
//std::cerr << "cachedShortWt_.2=" << cachedShortWt_->shape() << std::endl;
cachedShortWt_ = transpose(cachedShortWt_, {1, 0, 2, 3});
//std::cerr << "cachedShortWt_.3=" << cachedShortWt_->shape() << std::endl;
std::cerr << "cachedShortWt_.1=" << cachedShortWt_->shape() << std::endl;
cachedShortWt_ = reshape(cachedShortWt_, {currBeamSize, batchSize, k, cachedShortWt_->shape()[1]});
std::cerr << "cachedShortWt_.2=" << cachedShortWt_->shape() << std::endl;

if (b) {
ABORT("Bias not yet tested");
31 changes: 25 additions & 6 deletions src/layers/logits.cpp
Original file line number Diff line number Diff line change
@@ -8,6 +8,15 @@ Logits::Logits(Expr logits)
: Logits(New<RationalLoss>(logits, nullptr)) {
} // single-output constructor from Expr only (RationalLoss has no count)

Logits::Logits(Ptr<RationalLoss> logits) { // single-output constructor
logits_.push_back(logits);
}

Logits::Logits(std::vector<Ptr<RationalLoss>>&& logits,
Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor
: logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {
}

Ptr<ExpressionGraph> Logits::graph() const {
ABORT_IF(logits_.empty(), "Empty logits object??");
return logits_.front()->loss()->graph();
@@ -53,6 +62,7 @@ Expr Logits::applyLossFunction(
auto factorIndices = indices(maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply
auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor
auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet)
std::cerr << "g=" << g << " factorLogits->loss()=" << factorLogits->loss()->shape() << std::endl;
// For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next.
auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1]
// clang-format on
@@ -85,12 +95,14 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");

auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab]
std::cerr << "sel.1=" << sel->shape() << std::endl;

// normalize for decoding:
// - all secondary factors: subtract their max
// - lemma: add all maxes of applicable factors
if(groupIndex > 0) {
sel = sel - max(sel, -1);
std::cerr << "sel.2=" << sel->shape() << std::endl;
} else {
auto numGroups = getNumFactorGroups();
for(size_t g = 1; g < numGroups; g++) {
@@ -101,7 +113,7 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
factorMasks = constant(getFactorMasks(g, std::vector<WordIndex>()));
}
else {
//std::cerr << "sel=" << sel->shape() << std::endl;
std::cerr << "sel.3=" << sel->shape() << std::endl;
auto forward = [this, g](Expr out, const std::vector<Expr>& inputs) {
Expr lastIndices = inputs[0];
std::vector<float> masks = getFactorMasksMultiDim(g, lastIndices);
@@ -111,20 +123,27 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
int currBeamSize = sel->shape()[0];
int batchSize = sel->shape()[2];
Expr lastIndices = shortlist->getIndicesExpr(batchSize, currBeamSize);
//std::cerr << "lastIndices=" << lastIndices->shape() << std::endl;
std::cerr << "lastIndices=" << lastIndices->shape() << std::endl;
factorMasks = lambda({lastIndices}, lastIndices->shape(), Type::float32, forward);
//std::cerr << "factorMasks.1=" << factorMasks->shape() << std::endl;
std::cerr << "factorMasks.1=" << factorMasks->shape() << std::endl;
factorMasks = transpose(factorMasks, {1, 0, 2});
//std::cerr << "factorMasks.2=" << factorMasks->shape() << std::endl;
std::cerr << "factorMasks.2=" << factorMasks->shape() << std::endl;

const Shape &s = factorMasks->shape();
factorMasks = reshape(factorMasks, {s[0], 1, s[1], s[2]});
//std::cerr << "factorMasks.3=" << factorMasks->shape() << std::endl;
std::cerr << "factorMasks.3=" << factorMasks->shape() << std::endl;
}
factorMaxima = cast(factorMaxima, sel->value_type());
std::cerr << "factorMaxima=" << factorMaxima->shape() << std::endl;
factorMasks = cast(factorMasks, sel->value_type());
sel = sel + factorMaxima * factorMasks; // those lemmas that don't have a factor
std::cerr << "factorMasks.4=" << factorMasks->shape() << std::endl;

Expr tmp = factorMaxima * factorMasks;
std::cerr << "tmp=" << tmp->shape() << std::endl;
std::cerr << "sel.4=" << sel->shape() << std::endl;
sel = sel + tmp; // those lemmas that don't have a factor
// get multiplied with 0
std::cerr << "sel.5=" << sel->shape() << std::endl;
}
}

11 changes: 4 additions & 7 deletions src/layers/logits.h
Original file line number Diff line number Diff line change
@@ -17,14 +17,11 @@ class RationalLoss;
class Logits {
public:
Logits() {}
explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor
logits_.push_back(logits);
}
explicit Logits(
Expr logits); // single-output constructor from Expr only (RationalLoss has no count)
explicit Logits(Ptr<RationalLoss> logits); // single-output constructor
explicit Logits(Expr logits); // single-output constructor from Expr only (RationalLoss has no count)
Logits(std::vector<Ptr<RationalLoss>>&& logits,
Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor
: logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {}
Ptr<FactoredVocab> embeddingFactorMapping); // factored-output constructor

Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors
Expr getFactoredLogits(
size_t groupIndex,
26 changes: 15 additions & 11 deletions src/layers/output.cpp
Original file line number Diff line number Diff line change
@@ -63,9 +63,6 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
};

auto affineShortlist = [](Expr x, Expr W, Expr b, bool , bool ) {
//std::cerr << "x=" << x->shape() << std::endl;
//std::cerr << "W=" << W->shape() << std::endl;
x = transpose(x, {0, 2, 1, 3});
//std::cerr << "x=" << x->shape() << std::endl;
//std::cerr << "W=" << W->shape() << std::endl;
Expr ret = bdot(x, W, false, true);
@@ -174,29 +171,35 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
// matrix
Expr factorLogits;
if(g == 0 && shortlist_) {
//std::cerr << "affineShortlist.input1=" << input1->shape() << std::endl;
//std::cerr << "affineShortlist.factorWt=" << factorWt->shape() << std::endl;
std::cerr << "affineShortlist.input1=" << input1->shape() << std::endl;
std::cerr << "affineShortlist.factorWt=" << factorWt->shape() << std::endl;
Expr tmp = transpose(input1, {0, 2, 1, 3});
//std::cerr << "x=" << x->shape() << std::endl;
//std::cerr << "W=" << W->shape() << std::endl;
factorLogits = affineShortlist(
input1,
tmp,
factorWt,
factorB,
false,
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
//std::cerr << "affineShortlist.factorLogits.1=" << factorLogits->shape() << std::endl;
std::cerr << "affineShortlist.factorLogits.1=" << factorLogits->shape() << std::endl;
factorLogits = transpose(factorLogits, {0, 2, 1, 3});
//std::cerr << "affineShortlist.factorLogits.2=" << factorLogits->shape() << std::endl;
std::cerr << "affineShortlist.factorLogits.2=" << factorLogits->shape() << std::endl;
}
else {
//std::cerr << "affineOrDot.input1=" << input1->shape() << std::endl;
//std::cerr << "affineOrDot.factorWt=" << factorWt->shape() << std::endl;
std::cerr << "affineOrDot.input1=" << input1->shape() << std::endl;
std::cerr << "affineOrDot.factorWt.1=" << factorWt->shape() << std::endl;
//factorWt = transpose(factorWt, {1, 0, 2, 3});
//std::cerr << "affineOrDot.factorWt.2=" << factorWt->shape() << std::endl;
factorLogits = affineOrDot(
input1,
factorWt,
factorB,
false,
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
//std::cerr << "affineOrDot.factorLogits=" << factorLogits->shape() << std::endl;
std::cerr << "affineOrDot.factorLogits=" << factorLogits->shape() << std::endl;
}
std::cerr << std::endl;

// optionally add lemma-dependent bias
if(Plemma) { // [B... x U0]
@@ -210,6 +213,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
auto b = dot(Plemma, lemmaBt, false, true); // [B... x U]
factorLogits = factorLogits + b;
}
//std::cerr << "factorLogits=" << factorLogits->shape() << std::endl;
allLogits[g] = New<RationalLoss>(factorLogits, nullptr);
// optionally add a soft embedding of lemma back to create some lemma dependency
// @TODO: if this works, move it into lazyConstruct